mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-12 04:40:09 +02:00
138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
from collections.abc import Sequence
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
from datetime import timezone
|
|
from functools import lru_cache
|
|
|
|
from dateutil import tz
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import func
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_chat_accesssible_user
|
|
from onyx.db.engine import get_session_context_manager
|
|
from onyx.db.engine import get_session_with_tenant
|
|
from onyx.db.models import ChatMessage
|
|
from onyx.db.models import ChatSession
|
|
from onyx.db.models import TokenRateLimit
|
|
from onyx.db.models import User
|
|
from onyx.db.token_limit import fetch_all_global_token_rate_limits
|
|
from onyx.utils.logger import setup_logger
|
|
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
|
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
|
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
# Multiplier applied to a rate limit's `token_budget` before it is compared
# against summed token usage in `_is_rate_limited`, i.e. budgets are expressed
# in units of 1,000 tokens.
TOKEN_BUDGET_UNIT = 1_000
|
|
|
|
|
|
def check_token_rate_limits(
    user: User | None = Depends(current_chat_accesssible_user),
) -> None:
    """Enforce token rate limits for the current request.

    The `user` default is a FastAPI `Depends(...)`, so this is written to be
    used as a route dependency. When no rate limit is configured the function
    returns immediately; otherwise it resolves the versioned implementation of
    `_check_token_rate_limits` and runs it for `user` and the tenant id held in
    `CURRENT_TENANT_ID_CONTEXTVAR`.

    Args:
        user: The requesting user, or None.

    Returns:
        Whatever the resolved implementation returns (None for the
        implementation defined in this module).
    """
    # Short circuit if no rate limits are set up.
    # NOTE: `any_rate_limit_exists` is cached, so this check is fast 99% of the time.
    if not any_rate_limit_exists():
        return None

    # Look the checker up by module path and attribute name so that an
    # alternative version of `_check_token_rate_limits` can be substituted.
    rate_limit_check = fetch_versioned_implementation(
        "onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
    )
    return rate_limit_check(user, CURRENT_TENANT_ID_CONTEXTVAR.get())
|
|
|
|
|
|
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
    """Apply the global token rate limits for `tenant_id`.

    The user argument is accepted only to match the call made by
    `check_token_rate_limits`; this implementation ignores it and delegates to
    the global check.

    Args:
        _: The requesting user (unused).
        tenant_id: Tenant whose global limits are checked, or None.
    """
    _user_is_rate_limited_by_global(tenant_id)
|
|
|
|
|
|
"""
|
|
Global rate limits
|
|
"""
|
|
|
|
|
|
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
    """Raise an HTTP 429 error if a global token rate limit is exceeded.

    Despite the name, nothing is returned: the function either completes
    silently or raises.

    Args:
        tenant_id: Passed to `get_session_with_tenant` to open the DB session.

    Raises:
        HTTPException: With status 429 when `_is_rate_limited` reports that at
            least one enabled global rate limit has been exceeded.
    """
    with get_session_with_tenant(tenant_id) as db_session:
        limits = fetch_all_global_token_rate_limits(
            db_session=db_session, enabled_only=True, ordered=False
        )
        # Nothing enabled means nothing to enforce.
        if not limits:
            return

        # Only usage inside the longest configured window can matter.
        earliest_relevant = _get_cutoff_time(limits)
        usage = _fetch_global_usage(earliest_relevant, db_session)
        if not _is_rate_limited(limits, usage):
            return

        raise HTTPException(
            status_code=429,
            detail="Token budget exceeded for organization. Try again later.",
        )
|
|
|
|
|
|
def _fetch_global_usage(
    cutoff_time: datetime, db_session: Session
) -> Sequence[tuple[datetime, int]]:
    """Fetch per-minute token usage for messages sent at or after `cutoff_time`.

    Args:
        cutoff_time: Lower bound (inclusive) on `ChatMessage.time_sent`.
        db_session: Session used to run the query.

    Returns:
        A list of `(minute, token_sum)` pairs, one per minute that has at
        least one matching message. No ordering is requested from the database.
    """
    # `time_sent` truncated to the minute is both the selected column and the
    # grouping key.
    minute_bucket = func.date_trunc("minute", ChatMessage.time_sent)
    tokens_in_bucket = func.sum(ChatMessage.token_count)

    usage_query = (
        select(minute_bucket, tokens_in_bucket)
        # Inner join: only messages whose chat session row exists are counted.
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .filter(ChatMessage.time_sent >= cutoff_time)
        .group_by(minute_bucket)
    )
    rows = db_session.execute(usage_query).all()

    return [(minute, token_sum) for minute, token_sum in rows]
|
|
|
|
|
|
"""
|
|
Common functions
|
|
"""
|
|
|
|
|
|
def _get_cutoff_time(rate_limits: Sequence[TokenRateLimit]) -> datetime:
    """Return the earliest point in time that any of `rate_limits` looks back to.

    Args:
        rate_limits: Rate limits to consider; must be non-empty.

    Returns:
        The current UTC time (timezone-aware) minus the largest `period_hours`
        among the given rate limits.

    Raises:
        ValueError: If `rate_limits` is empty (from `max` on an empty iterable).
    """
    longest_window = timedelta(
        hours=max(limit.period_hours for limit in rate_limits)
    )
    return datetime.now(tz=timezone.utc) - longest_window
|
|
|
|
|
|
def _is_rate_limited(
    rate_limits: Sequence[TokenRateLimit], usage: Sequence[tuple[datetime, int]]
) -> bool:
    """Return True if at least one rate limit is exceeded.

    For each rate limit, the token counts of all usage entries whose timestamp
    falls within the last `period_hours` hours are summed and compared against
    `token_budget * TOKEN_BUDGET_UNIT`.

    Args:
        rate_limits: Rate limits to evaluate; an empty sequence yields False.
        usage: `(timestamp, token_count)` pairs. Timestamps are compared with a
            timezone-aware UTC datetime, so they are assumed to be
            timezone-aware as well (comparing a naive datetime would raise
            TypeError).

    Returns:
        True as soon as one limit's usage reaches or exceeds its budget,
        otherwise False.
    """
    # Take a single timestamp for the whole check so that every usage entry and
    # every rate limit is measured against the same instant, instead of calling
    # `datetime.now()` once per usage entry. `timezone.utc` matches what
    # `_get_cutoff_time` uses.
    now = datetime.now(tz=timezone.utc)

    for rate_limit in rate_limits:
        window_start = now - timedelta(hours=rate_limit.period_hours)
        tokens_used = sum(
            token_count
            for sent_at, token_count in usage
            if sent_at >= window_start
        )

        # `>=`: hitting the budget exactly already counts as rate limited.
        if tokens_used >= rate_limit.token_budget * TOKEN_BUDGET_UNIT:
            return True

    return False
|
|
|
|
|
|
@lru_cache()
def any_rate_limit_exists() -> bool:
    """Report whether at least one enabled token rate limit exists in the database.

    The result is memoized with `lru_cache`, so when no rate limits are set up
    this check has no effect on average query latency: the database is only
    queried on the first call (or after the cache is cleared).

    NOTE(review): the function takes no arguments, so a single cached value is
    shared by every caller in the process. Confirm the cache is cleared whenever
    rate limits are created, enabled, or disabled, and that sharing one value
    across tenants is intended.

    Returns:
        True if a `TokenRateLimit` row with `enabled` set was found, else False.
    """
    logger.debug("Checking for any rate limits...")

    enabled_limit_id_query = select(TokenRateLimit.id).where(
        TokenRateLimit.enabled == True  # noqa: E712
    )
    with get_session_context_manager() as db_session:
        # `scalar` yields the first id found, or None when there is no match.
        first_enabled_id = db_session.scalar(enabled_limit_id_query)

    return first_enabled_id is not None
|