mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-06 01:40:32 +02:00
125 lines
4.0 KiB
Python
125 lines
4.0 KiB
Python
from collections.abc import Sequence
|
|
|
|
from sqlalchemy import exists
|
|
from sqlalchemy import Row
|
|
from sqlalchemy import Select
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import aliased
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.configs.constants import TokenRateLimitScope
|
|
from onyx.db.models import TokenRateLimit
|
|
from onyx.db.models import TokenRateLimit__UserGroup
|
|
from onyx.db.models import User
|
|
from onyx.db.models import User__UserGroup
|
|
from onyx.db.models import UserGroup
|
|
from onyx.db.models import UserRole
|
|
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
|
|
|
|
|
def _add_user_filters(
|
|
stmt: Select, user: User | None, get_editable: bool = True
|
|
) -> Select:
|
|
# If user is None, assume the user is an admin or auth is disabled
|
|
if user is None or user.role == UserRole.ADMIN:
|
|
return stmt
|
|
|
|
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
|
|
User__UG = aliased(User__UserGroup)
|
|
|
|
"""
|
|
Here we select token_rate_limits by relation:
|
|
User -> User__UserGroup -> TokenRateLimit__UserGroup ->
|
|
TokenRateLimit
|
|
"""
|
|
stmt = stmt.outerjoin(TRLimit_UG).outerjoin(
|
|
User__UG,
|
|
User__UG.user_group_id == TRLimit_UG.user_group_id,
|
|
)
|
|
|
|
"""
|
|
Filter token_rate_limits by:
|
|
- if the user is in the user_group that owns the token_rate_limit
|
|
- if the user is not a global_curator, they must also have a curator relationship
|
|
to the user_group
|
|
- if editing is being done, we also filter out token_rate_limits that are owned by groups
|
|
that the user isn't a curator for
|
|
- if we are not editing, we show all token_rate_limits in the groups the user curates
|
|
"""
|
|
where_clause = User__UG.user_id == user.id
|
|
if user.role == UserRole.CURATOR and get_editable:
|
|
where_clause &= User__UG.is_curator == True # noqa: E712
|
|
if get_editable:
|
|
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
|
|
if user.role == UserRole.CURATOR:
|
|
user_groups = user_groups.where(
|
|
User__UserGroup.is_curator == True # noqa: E712
|
|
)
|
|
where_clause &= (
|
|
~exists()
|
|
.where(TRLimit_UG.rate_limit_id == TokenRateLimit.id)
|
|
.where(~TRLimit_UG.user_group_id.in_(user_groups))
|
|
.correlate(TokenRateLimit)
|
|
)
|
|
|
|
return stmt.where(where_clause)
|
|
|
|
|
|
def fetch_all_user_group_token_rate_limits_by_group(
|
|
db_session: Session,
|
|
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
|
|
query = (
|
|
select(TokenRateLimit, UserGroup.name)
|
|
.join(
|
|
TokenRateLimit__UserGroup,
|
|
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
|
|
)
|
|
.join(UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id)
|
|
)
|
|
|
|
return db_session.execute(query).all()
|
|
|
|
|
|
def insert_user_group_token_rate_limit(
|
|
db_session: Session,
|
|
token_rate_limit_settings: TokenRateLimitArgs,
|
|
group_id: int,
|
|
) -> TokenRateLimit:
|
|
token_limit = TokenRateLimit(
|
|
enabled=token_rate_limit_settings.enabled,
|
|
token_budget=token_rate_limit_settings.token_budget,
|
|
period_hours=token_rate_limit_settings.period_hours,
|
|
scope=TokenRateLimitScope.USER_GROUP,
|
|
)
|
|
db_session.add(token_limit)
|
|
db_session.flush()
|
|
|
|
rate_limit = TokenRateLimit__UserGroup(
|
|
rate_limit_id=token_limit.id, user_group_id=group_id
|
|
)
|
|
db_session.add(rate_limit)
|
|
db_session.commit()
|
|
|
|
return token_limit
|
|
|
|
|
|
def fetch_user_group_token_rate_limits(
|
|
db_session: Session,
|
|
group_id: int,
|
|
user: User | None = None,
|
|
enabled_only: bool = False,
|
|
ordered: bool = True,
|
|
get_editable: bool = True,
|
|
) -> Sequence[TokenRateLimit]:
|
|
stmt = select(TokenRateLimit)
|
|
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
|
stmt = _add_user_filters(stmt, user, get_editable)
|
|
|
|
if enabled_only:
|
|
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
|
|
|
|
if ordered:
|
|
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
|
|
|
|
return db_session.scalars(stmt).all()
|