diff --git a/backend/onyx/db/auth.py b/backend/onyx/db/auth.py index 974d1b83ee46..c9e5c57f7971 100644 --- a/backend/onyx/db/auth.py +++ b/backend/onyx/db/auth.py @@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator from collections.abc import Callable from typing import Any from typing import Dict +from typing import TypeVar from fastapi import Depends from fastapi_users.models import ID @@ -9,11 +10,11 @@ from fastapi_users.models import UP from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase from sqlalchemy import func +from sqlalchemy import Select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import Session -from onyx.auth.invited_users import get_invited_users from onyx.auth.schemas import UserRole from onyx.db.api_key import get_api_key_email_pattern from onyx.db.engine import get_async_session @@ -25,6 +26,8 @@ from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) +T = TypeVar("T", bound=tuple[Any, ...]) + def get_default_admin_user_emails() -> list[str]: """Returns a list of emails who should default to Admin role. @@ -37,31 +40,44 @@ def get_default_admin_user_emails() -> list[str]: return get_default_admin_user_emails_fn() -def get_total_users_count(db_session: Session) -> int: +def _add_live_user_count_where_clause( + select_stmt: Select[T], + only_admin_users: bool, +) -> Select[T]: """ - Returns the total number of users in the system. - This is the sum of users and invited users. + Builds a SQL column expression that can be used to filter out + users who should not be included in the live user count. """ - user_count = ( - db_session.query(User) - .filter( - ~User.email.endswith(get_api_key_email_pattern()), # type: ignore - User.role != UserRole.EXT_PERM_USER, - ) - .count() + select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore + if only_admin_users: + return select_stmt.where(User.role == UserRole.ADMIN) + + return select_stmt.where( + User.role != UserRole.EXT_PERM_USER, ) - invited_users = len(get_invited_users()) - return user_count + invited_users + + +def get_live_users_count(db_session: Session) -> int: + """ + Returns the number of users in the system. + This does NOT include invited users, "users" pulled in + from external connectors, or API keys. + """ + count_stmt = func.count(User.id) # type: ignore + select_stmt = select(count_stmt) + select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False) + user_count = db_session.scalar(select_stmt_w_filters) + if user_count is None: + raise RuntimeError("Was not able to fetch the user count.") + return user_count async def get_user_count(only_admin_users: bool = False) -> int: async with get_async_session_context_manager() as session: count_stmt = func.count(User.id) # type: ignore stmt = select(count_stmt) - if only_admin_users: - stmt = stmt.where(User.role == UserRole.ADMIN) - result = await session.execute(stmt) - user_count = result.scalar() + stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users) + user_count = await session.scalar(stmt_w_filters) if user_count is None: raise RuntimeError("Was not able to fetch the user count.") return user_count diff --git a/backend/onyx/server/manage/users.py b/backend/onyx/server/manage/users.py index a179128dc6a1..9c93b4e93cef 100644 --- a/backend/onyx/server/manage/users.py +++ b/backend/onyx/server/manage/users.py @@ -44,7 +44,7 @@ from onyx.configs.app_configs import VALID_EMAIL_DOMAINS from onyx.configs.constants import AuthType from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME from onyx.db.api_key import is_api_key_email_address -from onyx.db.auth import get_total_users_count +from onyx.db.auth import get_live_users_count from onyx.db.engine import get_session from onyx.db.models import AccessToken from onyx.db.models import User @@ -343,7 +343,7 @@ def bulk_invite_users( logger.info("Registering tenant users") fetch_ee_implementation_or_noop( "onyx.server.tenants.billing", "register_tenant_users", None - )(tenant_id, get_total_users_count(db_session)) + )(tenant_id, get_live_users_count(db_session)) return number_of_invited_users except Exception as e: @@ -379,7 +379,7 @@ def remove_invited_user( if MULTI_TENANT and not DEV_MODE: fetch_ee_implementation_or_noop( "onyx.server.tenants.billing", "register_tenant_users", None - )(tenant_id, get_total_users_count(db_session)) + )(tenant_id, get_live_users_count(db_session)) except Exception: logger.error( "Request to update number of seats taken in control plane failed. "