danswer/backend/ee/onyx/db/token_limit.py
hagen-danswer b1957737f2
refactored _add_user_filter usage (#3674)
* refactored db.connector_credential_pair

* Refactored the db.credentials user filtering

* the restr
2025-01-14 23:35:52 +00:00

133 lines
4.3 KiB
Python

from collections.abc import Sequence
from sqlalchemy import exists
from sqlalchemy import Row
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
    stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
    """Restrict a ``select(TokenRateLimit)`` statement to limits ``user`` may see.

    Args:
        stmt: A select statement whose primary entity is ``TokenRateLimit``.
        user: The requesting user, or ``None`` for an anonymous request.
        get_editable: When True, only keep limits the user may edit (see rules).

    Returns:
        The filtered statement. For admins (or a ``None`` user while
        ``DISABLE_AUTH`` is set) ``stmt`` is returned unchanged.

    Filtering rules for everyone else:
      - anonymous users (``user is None``) only see GLOBAL-scoped limits;
      - otherwise the user must be a member of a user group linked to the limit;
      - with ``get_editable`` and a CURATOR role, that membership must also be
        a curator membership;
      - with ``get_editable``, limits that are also linked to any group outside
        the user's (curated, for CURATORs) groups are excluded.
    """
    # If user is None and auth is disabled, assume the user is an admin
    if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
        return stmt

    # The outer joins below can produce one row per (limit, group, member);
    # DISTINCT collapses them back to one row per TokenRateLimit.
    stmt = stmt.distinct()
    TRLimit_UG = aliased(TokenRateLimit__UserGroup)
    User__UG = aliased(User__UserGroup)

    # Relate token rate limits to users via:
    # TokenRateLimit -> TokenRateLimit__UserGroup -> User__UserGroup
    stmt = stmt.outerjoin(TRLimit_UG).outerjoin(
        User__UG,
        User__UG.user_group_id == TRLimit_UG.user_group_id,
    )

    # If user is None, this is an anonymous user and we should only show
    # public (GLOBAL-scoped) token_rate_limits
    if user is None:
        return stmt.where(TokenRateLimit.scope == TokenRateLimitScope.GLOBAL)

    where_clause = User__UG.user_id == user.id
    if not get_editable:
        return stmt.where(where_clause)

    is_curator = user.role == UserRole.CURATOR
    if is_curator:
        where_clause &= User__UG.is_curator == True  # noqa: E712

    # Ids of the groups whose limits this user may edit. Every condition is on
    # the same (plain) table, and correlation is disabled explicitly, so this
    # runs as a standalone subquery instead of being tied to the outer query's
    # joined User__UG row.
    editable_group_ids = (
        select(User__UserGroup.user_group_id)
        .where(User__UserGroup.user_id == user.id)
        .correlate(None)
    )
    if is_curator:
        editable_group_ids = editable_group_ids.where(
            User__UserGroup.is_curator == True  # noqa: E712
        )

    # Drop any limit that is also linked to a group outside that set. A separate
    # alias keeps this NOT EXISTS visibly independent of the outer join.
    Owner_UG = aliased(TokenRateLimit__UserGroup)
    where_clause &= (
        ~exists()
        .where(Owner_UG.rate_limit_id == TokenRateLimit.id)
        .where(~Owner_UG.user_group_id.in_(editable_group_ids))
        .correlate(TokenRateLimit)
    )

    return stmt.where(where_clause)
def fetch_all_user_group_token_rate_limits_by_group(
    db_session: Session,
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
    """Return every token rate limit linked to a user group, with the group name.

    Each row is ``(TokenRateLimit, UserGroup.name)``. Because both joins are
    inner joins, a limit linked to several groups appears once per linked group.
    """
    link = TokenRateLimit__UserGroup
    limits_with_group_names = (
        select(TokenRateLimit, UserGroup.name)
        .join(link, TokenRateLimit.id == link.rate_limit_id)
        .join(UserGroup, UserGroup.id == link.user_group_id)
    )
    result = db_session.execute(limits_with_group_names)
    return result.all()
def insert_user_group_token_rate_limit(
    db_session: Session,
    token_rate_limit_settings: TokenRateLimitArgs,
    group_id: int,
) -> TokenRateLimit:
    """Create a USER_GROUP-scoped token rate limit and link it to a group.

    Args:
        db_session: Active database session; it is committed before returning.
        token_rate_limit_settings: Enabled flag, token budget and period (hours)
            for the new limit.
        group_id: Id of the user group that will own the limit.

    Returns:
        The newly created ``TokenRateLimit``.
    """
    settings = token_rate_limit_settings
    new_limit = TokenRateLimit(
        enabled=settings.enabled,
        token_budget=settings.token_budget,
        period_hours=settings.period_hours,
        scope=TokenRateLimitScope.USER_GROUP,
    )
    db_session.add(new_limit)
    # Flush so new_limit.id is populated (presumably DB-generated) before the
    # link row below references it.
    db_session.flush()

    db_session.add(
        TokenRateLimit__UserGroup(rate_limit_id=new_limit.id, user_group_id=group_id)
    )
    db_session.commit()
    return new_limit
def fetch_user_group_token_rate_limits_for_user(
    db_session: Session,
    group_id: int,
    user: User | None,
    enabled_only: bool = False,
    ordered: bool = True,
    get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
    """Return the token rate limits linked to ``group_id`` that ``user`` may see.

    Args:
        db_session: Active database session.
        group_id: Id of the user group whose limits are requested.
        user: Requesting user (``None`` for anonymous); visibility rules are
            applied by ``_add_user_filters``.
        enabled_only: If True, only return limits whose ``enabled`` flag is set.
        ordered: If True, order results newest first by ``created_at``.
        get_editable: Forwarded to ``_add_user_filters``.

    Returns:
        The matching ``TokenRateLimit`` rows.
    """
    # Select the group's limits through the TokenRateLimit__UserGroup link
    # table. Filtering on User__UserGroup.user_group_id alone never relates that
    # table to TokenRateLimit, which cross-joins and matches every limit
    # whenever the group has any member.
    group_rate_limit_ids = select(TokenRateLimit__UserGroup.rate_limit_id).where(
        TokenRateLimit__UserGroup.user_group_id == group_id
    )
    stmt = select(TokenRateLimit).where(TokenRateLimit.id.in_(group_rate_limit_ids))
    stmt = _add_user_filters(stmt, user, get_editable)
    if enabled_only:
        stmt = stmt.where(TokenRateLimit.enabled.is_(True))
    if ordered:
        stmt = stmt.order_by(TokenRateLimit.created_at.desc())
    return db_session.scalars(stmt).all()