danswer/backend/ee/onyx/db/token_limit.py
2024-12-13 09:56:10 -08:00

125 lines
4.0 KiB
Python

from collections.abc import Sequence
from sqlalchemy import exists
from sqlalchemy import Row
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return stmt
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
User__UG = aliased(User__UserGroup)
"""
Here we select token_rate_limits by relation:
User -> User__UserGroup -> TokenRateLimit__UserGroup ->
TokenRateLimit
"""
stmt = stmt.outerjoin(TRLimit_UG).outerjoin(
User__UG,
User__UG.user_group_id == TRLimit_UG.user_group_id,
)
"""
Filter token_rate_limits by:
- if the user is in the user_group that owns the token_rate_limit
- if the user is not a global_curator, they must also have a curator relationship
to the user_group
- if editing is being done, we also filter out token_rate_limits that are owned by groups
that the user isn't a curator for
- if we are not editing, we show all token_rate_limits in the groups the user curates
"""
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
if get_editable:
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
if user.role == UserRole.CURATOR:
user_groups = user_groups.where(
User__UserGroup.is_curator == True # noqa: E712
)
where_clause &= (
~exists()
.where(TRLimit_UG.rate_limit_id == TokenRateLimit.id)
.where(~TRLimit_UG.user_group_id.in_(user_groups))
.correlate(TokenRateLimit)
)
return stmt.where(where_clause)
def fetch_all_user_group_token_rate_limits_by_group(
db_session: Session,
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
query = (
select(TokenRateLimit, UserGroup.name)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.join(UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id)
)
return db_session.execute(query).all()
def insert_user_group_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
group_id: int,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER_GROUP,
)
db_session.add(token_limit)
db_session.flush()
rate_limit = TokenRateLimit__UserGroup(
rate_limit_id=token_limit.id, user_group_id=group_id
)
db_session.add(rate_limit)
db_session.commit()
return token_limit
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
if ordered:
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(stmt).all()