Fix user count (#4677)

* Fix user count

* Add helper + fix async function as well

* fix mypy

* Address RK comment
This commit is contained in:
Chris Weaver
2025-05-08 17:19:40 -07:00
committed by GitHub
parent 1dd98a87cc
commit 91831f4d07
2 changed files with 36 additions and 20 deletions

View File

@@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
from collections.abc import Callable
from typing import Any
from typing import Dict
from typing import TypeVar
from fastapi import Depends
from fastapi_users.models import ID
@@ -9,11 +10,11 @@ from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import get_api_key_email_pattern
from onyx.db.engine import get_async_session
@@ -25,6 +26,8 @@ from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
T = TypeVar("T", bound=tuple[Any, ...])
def get_default_admin_user_emails() -> list[str]:
"""Returns a list of emails who should default to Admin role.
@@ -37,31 +40,44 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users_count(db_session: Session) -> int:
def _add_live_user_count_where_clause(
select_stmt: Select[T],
only_admin_users: bool,
) -> Select[T]:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
Builds a SQL column expression that can be used to filter out
users who should not be included in the live user count.
"""
user_count = (
db_session.query(User)
.filter(
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
User.role != UserRole.EXT_PERM_USER,
)
.count()
select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
if only_admin_users:
return select_stmt.where(User.role == UserRole.ADMIN)
return select_stmt.where(
User.role != UserRole.EXT_PERM_USER,
)
invited_users = len(get_invited_users())
return user_count + invited_users
def get_live_users_count(db_session: Session) -> int:
"""
Returns the number of users in the system.
This does NOT include invited users, "users" pulled in
from external connectors, or API keys.
"""
count_stmt = func.count(User.id) # type: ignore
select_stmt = select(count_stmt)
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
user_count = db_session.scalar(select_stmt_w_filters)
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_context_manager() as session:
count_stmt = func.count(User.id) # type: ignore
stmt = select(count_stmt)
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)
user_count = result.scalar()
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
user_count = await session.scalar(stmt_w_filters)
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count

View File

@@ -44,7 +44,7 @@ from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.constants import AuthType
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_total_users_count
from onyx.db.auth import get_live_users_count
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import User
@@ -343,7 +343,7 @@ def bulk_invite_users(
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
)(tenant_id, get_live_users_count(db_session))
return number_of_invited_users
except Exception as e:
@@ -379,7 +379,7 @@ def remove_invited_user(
if MULTI_TENANT and not DEV_MODE:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
)(tenant_id, get_live_users_count(db_session))
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "