mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-07 11:20:18 +02:00
Fix user count (#4677)
* Fix user count * Add helper + fix async function as well * fix mypy * Address RK comment
This commit is contained in:
@@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_users.models import ID
|
||||
@@ -9,11 +10,11 @@ from fastapi_users.models import UP
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import get_api_key_email_pattern
|
||||
from onyx.db.engine import get_async_session
|
||||
@@ -25,6 +26,8 @@ from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound=tuple[Any, ...])
|
||||
|
||||
|
||||
def get_default_admin_user_emails() -> list[str]:
|
||||
"""Returns a list of emails who should default to Admin role.
|
||||
@@ -37,31 +40,44 @@ def get_default_admin_user_emails() -> list[str]:
|
||||
return get_default_admin_user_emails_fn()
|
||||
|
||||
|
||||
def get_total_users_count(db_session: Session) -> int:
|
||||
def _add_live_user_count_where_clause(
|
||||
select_stmt: Select[T],
|
||||
only_admin_users: bool,
|
||||
) -> Select[T]:
|
||||
"""
|
||||
Returns the total number of users in the system.
|
||||
This is the sum of users and invited users.
|
||||
Builds a SQL column expression that can be used to filter out
|
||||
users who should not be included in the live user count.
|
||||
"""
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
.count()
|
||||
select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
|
||||
if only_admin_users:
|
||||
return select_stmt.where(User.role == UserRole.ADMIN)
|
||||
|
||||
return select_stmt.where(
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
invited_users = len(get_invited_users())
|
||||
return user_count + invited_users
|
||||
|
||||
|
||||
def get_live_users_count(db_session: Session) -> int:
|
||||
"""
|
||||
Returns the number of users in the system.
|
||||
This does NOT include invited users, "users" pulled in
|
||||
from external connectors, or API keys.
|
||||
"""
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
select_stmt = select(count_stmt)
|
||||
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
|
||||
user_count = db_session.scalar(select_stmt_w_filters)
|
||||
if user_count is None:
|
||||
raise RuntimeError("Was not able to fetch the user count.")
|
||||
return user_count
|
||||
|
||||
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_context_manager() as session:
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
stmt = select(count_stmt)
|
||||
if only_admin_users:
|
||||
stmt = stmt.where(User.role == UserRole.ADMIN)
|
||||
result = await session.execute(stmt)
|
||||
user_count = result.scalar()
|
||||
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
|
||||
user_count = await session.scalar(stmt_w_filters)
|
||||
if user_count is None:
|
||||
raise RuntimeError("Was not able to fetch the user count.")
|
||||
return user_count
|
||||
|
@@ -44,7 +44,7 @@ from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_total_users_count
|
||||
from onyx.db.auth import get_live_users_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import User
|
||||
@@ -343,7 +343,7 @@ def bulk_invite_users(
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_total_users_count(db_session))
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
|
||||
return number_of_invited_users
|
||||
except Exception as e:
|
||||
@@ -379,7 +379,7 @@ def remove_invited_user(
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_total_users_count(db_session))
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Request to update number of seats taken in control plane failed. "
|
||||
|
Reference in New Issue
Block a user