from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from functools import lru_cache
from dateutil import tz
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_global_token_rate_limits
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

# Budgets are configured in units of 1,000 tokens (see _is_rate_limited below).
TOKEN_BUDGET_UNIT = 1_000


def check_token_rate_limits(
user: User | None = Depends(current_chat_accesssible_user),
) -> None:
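    """FastAPI dependency that rejects a request with HTTP 429 once a token
    budget is exhausted.

    Example wiring (hypothetical sketch; the router and path are illustrative,
    not part of this module):

        @router.post("/chat", dependencies=[Depends(check_token_rate_limits)])
        async def handle_chat() -> None:
            ...
    """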
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
if not any_rate_limit_exists():
return
versioned_rate_limit_strategy = fetch_versioned_implementation(
"onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
)
    return versioned_rate_limit_strategy(user, CURRENT_TENANT_ID_CONTEXTVAR.get())


def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
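    """Default rate-limit strategy: only global (organization-wide) limits are
    checked, so the user argument is unused here. A versioned implementation
    fetched above may replace this with a stricter strategy (presumably adding
    finer-grained, e.g. per-user, checks)."""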
    _user_is_rate_limited_by_global(tenant_id)


"""
Global rate limits
"""


def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
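    """Raise HTTPException(429) if organization-wide token usage exceeds any
    enabled global budget within that budget's window."""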
with get_session_with_tenant(tenant_id) as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
if global_rate_limits:
global_cutoff_time = _get_cutoff_time(global_rate_limits)
global_usage = _fetch_global_usage(global_cutoff_time, db_session)
if _is_rate_limited(global_rate_limits, global_usage):
raise HTTPException(
status_code=429,
detail="Token budget exceeded for organization. Try again later.",
                )


def _fetch_global_usage(
cutoff_time: datetime, db_session: Session
) -> Sequence[tuple[datetime, int]]:
"""
Fetch global token usage within the cutoff time, grouped by minute
"""
result = db_session.execute(
select(
func.date_trunc("minute", ChatMessage.time_sent),
func.sum(ChatMessage.token_count),
)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.filter(
ChatMessage.time_sent >= cutoff_time,
)
.group_by(func.date_trunc("minute", ChatMessage.time_sent))
).all()
return [(row[0], row[1]) for row in result]
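
# The statement above corresponds roughly to the following SQL (a sketch only;
# actual table and column names depend on the ORM model definitions):
#
#   SELECT date_trunc('minute', time_sent), SUM(token_count)
#   FROM chat_message
#   JOIN chat_session ON chat_message.chat_session_id = chat_session.id
#   WHERE time_sent >= :cutoff_time
#   GROUP BY date_trunc('minute', time_sent)
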
"""
Common functions
"""
def _get_cutoff_time(rate_limits: Sequence[TokenRateLimit]) -> datetime:
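    """Return now minus the longest enabled window, e.g. limits over 1-hour
    and 24-hour windows yield a 24-hour cutoff, so one usage query can serve
    every window that _is_rate_limited later checks."""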
max_period_hours = max(rate_limit.period_hours for rate_limit in rate_limits)
    return datetime.now(tz=timezone.utc) - timedelta(hours=max_period_hours)


def _is_rate_limited(
rate_limits: Sequence[TokenRateLimit], usage: Sequence[tuple[datetime, int]]
) -> bool:
"""
If at least one rate limit is exceeded, return True
"""
for rate_limit in rate_limits:
tokens_used = sum(
u_token_count
for u_date, u_token_count in usage
if u_date
>= datetime.now(tz=tz.UTC) - timedelta(hours=rate_limit.period_hours)
)
if tokens_used >= rate_limit.token_budget * TOKEN_BUDGET_UNIT:
return True
    return False


@lru_cache()
def any_rate_limit_exists() -> bool:
"""Checks if any rate limit exists in the database. Is cached, so that if no rate limits
are setup, we don't have any effect on average query latency."""
logger.debug("Checking for any rate limits...")
with get_session_context_manager() as db_session:
return (
db_session.scalar(
select(TokenRateLimit.id).where(
TokenRateLimit.enabled == True # noqa: E712
)
)
is not None
)
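
# NOTE: lru_cache memoizes the result for the process lifetime, so the first
# answer is reused until any_rate_limit_exists.cache_clear() is called
# (presumably done wherever rate limits are created, toggled, or deleted).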