Token Rate Limiting

WIP

Cleanup 🧹

Remove existing rate limiting logic

Cleanup 🧼

Undo nit

Cleanup 🧽

Move db constants (avoids circular import)

WIP

WIP

Cleanup

Lint

Resolve alembic conflict

Fix mypy

Add backfill to migration

Update comment

Make unauthenticated users still adhere to global limits

Use Depends

Remove enum from table

Update migration error handling + deletion

Address verbal feedback, cleanup urls, minor nits
This commit is contained in:
Alan Hagedorn 2024-04-14 18:53:38 -07:00 committed by Chris Weaver
parent 7a408749cf
commit d7a704c0d9
24 changed files with 1497 additions and 298 deletions

View File

@ -41,10 +41,6 @@ DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
TOKEN_BUDGET = "token_budget"
TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period"
ENABLE_TOKEN_BUDGET = "enable_token_budget"
TOKEN_BUDGET_SETTINGS = "token_budget_settings"
# For chunking/processing chunks
TITLE_SEPARATOR = "\n\r\n"

View File

@ -83,6 +83,9 @@ from danswer.server.settings.api import basic_router as settings_router
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@ -281,6 +284,9 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, settings_admin_router)
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@ -1,23 +1,16 @@
import json
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
from fastapi import APIRouter
from fastapi import Body
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import ENABLE_TOKEN_BUDGET
from danswer.configs.constants import TOKEN_BUDGET
from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_session
@ -193,41 +186,3 @@ def create_deletion_attempt_for_connector_id(
file_store = get_default_file_store(db_session)
for file_name in connector.connector_specific_config["file_locations"]:
file_store.delete_file(file_name)
@router.get("/admin/token-budget-settings")
def get_token_budget_settings(_: User = Depends(current_admin_user)) -> dict:
if not TOKEN_BUDGET_GLOBALLY_ENABLED:
raise HTTPException(
status_code=400, detail="Token budget is not enabled in the application."
)
try:
settings_json = cast(
str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS)
)
settings = json.loads(settings_json)
return settings
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Token budget settings not found.")
@router.put("/admin/token-budget-settings")
def update_token_budget_settings(
_: User = Depends(current_admin_user),
enable_token_budget: bool = Body(..., embed=True),
token_budget: int = Body(..., ge=0, embed=True), # Ensure non-negative
token_budget_time_period: int = Body(..., ge=1, embed=True), # Ensure positive
) -> dict[str, str]:
# Prepare the settings as a JSON string
settings_json = json.dumps(
{
ENABLE_TOKEN_BUDGET: enable_token_budget,
TOKEN_BUDGET: token_budget,
TOKEN_BUDGET_TIME_PERIOD: token_budget_time_period,
}
)
# Store the settings in the dynamic config store
get_dynamic_config_store().store(TOKEN_BUDGET_SETTINGS, settings_json)
return {"message": "Token budget settings updated successfully."}

View File

@ -64,6 +64,7 @@ from danswer.server.query_and_chat.models import PromptOverride
from danswer.server.query_and_chat.models import RenameChatSessionResponse
from danswer.server.query_and_chat.models import SearchFeedbackRequest
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -275,6 +276,7 @@ def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
) -> StreamingResponse:
"""This endpoint is both used for all the following purposes:
- Sending a new message in the session

View File

@ -29,7 +29,7 @@ from danswer.server.query_and_chat.models import QueryValidationResponse
from danswer.server.query_and_chat.models import SimpleQueryRequest
from danswer.server.query_and_chat.models import SourceTag
from danswer.server.query_and_chat.models import TagResponse
from danswer.server.query_and_chat.token_budget import check_token_budget
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -150,7 +150,7 @@ def stream_query_validation(
def get_answer_with_quote(
query_request: DirectQARequest,
user: User = Depends(current_user),
_: bool = Depends(check_token_budget),
_: None = Depends(check_token_rate_limits),
) -> StreamingResponse:
query = query_request.messages[0].message
logger.info(f"Received query for one shot answer with quotes: {query}")

View File

@ -1,79 +0,0 @@
import json
from datetime import datetime
from datetime import timedelta
from typing import cast
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
from danswer.configs.constants import ENABLE_TOKEN_BUDGET
from danswer.configs.constants import TOKEN_BUDGET
from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.dynamic_configs.factory import get_dynamic_config_store
BUDGET_LIMIT_DEFAULT = -1 # Default to no limit
TIME_PERIOD_HOURS_DEFAULT = 12
def is_under_token_budget(db_session: Session) -> bool:
try:
settings_json = cast(
str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS)
)
except Exception:
return True
settings = json.loads(settings_json)
is_enabled = settings.get(ENABLE_TOKEN_BUDGET, False)
if not is_enabled:
return True
budget_limit = settings.get(TOKEN_BUDGET, -1)
if budget_limit < 0:
return True
period_hours = settings.get(TOKEN_BUDGET_TIME_PERIOD, TIME_PERIOD_HOURS_DEFAULT)
period_start_time = datetime.now() - timedelta(hours=period_hours)
# Fetch the sum of all tokens used within the period
token_sum = (
db_session.query(func.sum(ChatMessage.token_count))
.filter(ChatMessage.time_sent >= period_start_time)
.scalar()
or 0
)
print(
"token_sum:",
token_sum,
"budget_limit:",
budget_limit,
"period_hours:",
period_hours,
"period_start_time:",
period_start_time,
)
return token_sum < (
budget_limit * 1000
) # Budget limit is expressed in thousands of tokens
def check_token_budget() -> None:
if not TOKEN_BUDGET_GLOBALLY_ENABLED:
return None
with get_session_context_manager() as db_session:
# Perform the token budget check here, possibly using `user` and `db_session` for database access if needed
if not is_under_token_budget(db_session):
raise HTTPException(
status_code=429, detail="Sorry, token budget exceeded. Try again later."
)

View File

@ -0,0 +1,135 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from functools import lru_cache
from dateutil import tz
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
logger = setup_logger()
TOKEN_BUDGET_UNIT = 1_000
def check_token_rate_limits(
user: User | None = Depends(current_user),
) -> None:
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
if not any_rate_limit_exists():
return
versioned_rate_limit_strategy = fetch_versioned_implementation(
"danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
)
return versioned_rate_limit_strategy(user)
def _check_token_rate_limits(_: User | None) -> None:
_user_is_rate_limited_by_global()
"""
Global rate limits
"""
def _user_is_rate_limited_by_global() -> None:
with get_session_context_manager() as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
if global_rate_limits:
global_cutoff_time = _get_cutoff_time(global_rate_limits)
global_usage = _fetch_global_usage(global_cutoff_time, db_session)
if _is_rate_limited(global_rate_limits, global_usage):
raise HTTPException(
status_code=429,
detail="Token budget exceeded for organization. Try again later.",
)
def _fetch_global_usage(
cutoff_time: datetime, db_session: Session
) -> Sequence[tuple[datetime, int]]:
"""
Fetch global token usage within the cutoff time, grouped by minute
"""
result = db_session.execute(
select(
func.date_trunc("minute", ChatMessage.time_sent),
func.sum(ChatMessage.token_count),
)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.filter(
ChatMessage.time_sent >= cutoff_time,
)
.group_by(func.date_trunc("minute", ChatMessage.time_sent))
).all()
return [(row[0], row[1]) for row in result]
"""
Common functions
"""
def _get_cutoff_time(rate_limits: Sequence[TokenRateLimit]) -> datetime:
max_period_hours = max(rate_limit.period_hours for rate_limit in rate_limits)
return datetime.now(tz=timezone.utc) - timedelta(hours=max_period_hours)
def _is_rate_limited(
rate_limits: Sequence[TokenRateLimit], usage: Sequence[tuple[datetime, int]]
) -> bool:
"""
If at least one rate limit is exceeded, return True
"""
for rate_limit in rate_limits:
tokens_used = sum(
u_token_count
for u_date, u_token_count in usage
if u_date
>= datetime.now(tz=tz.UTC) - timedelta(hours=rate_limit.period_hours)
)
if tokens_used >= rate_limit.token_budget * TOKEN_BUDGET_UNIT:
return True
return False
@lru_cache()
def any_rate_limit_exists() -> bool:
"""Checks if any rate limit exists in the database. Is cached, so that if no rate limits
are setup, we don't have any effect on average query latency."""
logger.info("Checking for any rate limits...")
with get_session_context_manager() as db_session:
return (
db_session.scalar(
select(TokenRateLimit.id).where(
TokenRateLimit.enabled == True # noqa: E712
)
)
is not None
)

View File

@ -0,0 +1,79 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import delete_token_rate_limit
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from ee.danswer.db.token_limit import insert_global_token_rate_limit
from ee.danswer.db.token_limit import update_token_rate_limit
router = APIRouter(prefix="/admin/token-rate-limits")
"""
Global Token Limit Settings
"""
@router.get("/global")
def get_global_token_limit_settings(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_all_global_token_rate_limits(db_session)
]
@router.post("/global")
def create_global_token_limit_settings(
token_limit_settings: TokenRateLimitArgs,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
rate_limit_display = TokenRateLimitDisplay.from_db(
insert_global_token_rate_limit(db_session, token_limit_settings)
)
# clear cache in case this was the first rate limit created
any_rate_limit_exists.cache_clear()
return rate_limit_display
"""
General Token Limit Settings
"""
@router.put("/rate-limit/{token_rate_limit_id}")
def update_token_limit_settings(
token_rate_limit_id: int,
token_limit_settings: TokenRateLimitArgs,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
return TokenRateLimitDisplay.from_db(
update_token_rate_limit(
db_session=db_session,
token_rate_limit_id=token_rate_limit_id,
token_rate_limit_settings=token_limit_settings,
)
)
@router.delete("/rate-limit/{token_rate_limit_id}")
def delete_token_limit_settings(
token_rate_limit_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
return delete_token_rate_limit(
db_session=db_session,
token_rate_limit_id=token_rate_limit_id,
)

View File

@ -0,0 +1,25 @@
from pydantic import BaseModel
from danswer.db.models import TokenRateLimit
class TokenRateLimitArgs(BaseModel):
enabled: bool
token_budget: int
period_hours: int
class TokenRateLimitDisplay(BaseModel):
token_id: int
enabled: bool
token_budget: int
period_hours: int
@classmethod
def from_db(cls, token_rate_limit: TokenRateLimit) -> "TokenRateLimitDisplay":
return cls(
token_id=token_rate_limit.id,
enabled=token_rate_limit.enabled,
token_budget=token_rate_limit.token_budget,
period_hours=token_rate_limit.period_hours,
)

View File

@ -0,0 +1,176 @@
from collections.abc import Sequence
from sqlalchemy import Row
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def fetch_all_user_group_token_rate_limits(
db_session: Session, group_id: int, enabled_only: bool = False, ordered: bool = True
) -> Sequence[TokenRateLimit]:
query = (
select(TokenRateLimit)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.where(
TokenRateLimit__UserGroup.user_group_id == group_id,
TokenRateLimit.scope == TokenRateLimitScope.USER_GROUP,
)
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def fetch_all_user_group_token_rate_limits_by_group(
db_session: Session,
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
query = (
select(TokenRateLimit, UserGroup.name)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.join(UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id)
)
return db_session.execute(query).all()
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_user_group_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
group_id: int,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER_GROUP,
)
db_session.add(token_limit)
db_session.flush()
rate_limit = TokenRateLimit__UserGroup(
rate_limit_id=token_limit.id, user_group_id=group_id
)
db_session.add(rate_limit)
db_session.commit()
return token_limit
def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
return token_limit
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
@ -241,6 +242,21 @@ def update_user_group(
return db_user_group
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
@ -255,6 +271,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()

View File

@ -34,6 +34,9 @@ from ee.danswer.server.query_and_chat.query_backend import (
)
from ee.danswer.server.query_history.api import router as query_history_router
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from ee.danswer.server.user_group.api import router as user_group_router
logger = setup_logger()
@ -85,6 +88,10 @@ def get_ee_application() -> FastAPI:
include_router_with_global_prefix_prepended(
application, enterprise_settings_admin_router
)
# Token rate limit settings
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
# Ensure all routes have auth enabled or are explicitly marked as public

View File

@ -0,0 +1,184 @@
from collections import defaultdict
from collections.abc import Sequence
from datetime import datetime
from itertools import groupby
from typing import Dict
from typing import List
from typing import Tuple
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.server.query_and_chat.token_limit import _get_cutoff_time
from danswer.server.query_and_chat.token_limit import _is_rate_limited
from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
def _check_token_rate_limits(user: User | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global()
elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global()
else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
]
)
"""
User rate limits
"""
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_context_manager() as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
if user_rate_limits:
user_cutoff_time = _get_cutoff_time(user_rate_limits)
user_usage = _fetch_user_usage(user_id, user_cutoff_time, db_session)
if _is_rate_limited(user_rate_limits, user_usage):
raise HTTPException(
status_code=429,
detail="Token budget exceeded for user. Try again later.",
)
def _fetch_user_usage(
user_id: UUID, cutoff_time: datetime, db_session: Session
) -> Sequence[tuple[datetime, int]]:
"""
Fetch user usage within the cutoff time, grouped by minute
"""
result = db_session.execute(
select(
func.date_trunc("minute", ChatMessage.time_sent),
func.sum(ChatMessage.token_count),
)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.where(ChatSession.user_id == user_id, ChatMessage.time_sent >= cutoff_time)
.group_by(func.date_trunc("minute", ChatMessage.time_sent))
).all()
return [(row[0], row[1]) for row in result]
"""
User Group rate limits
"""
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_context_manager() as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:
# Group cutoff time is the same for all groups.
# This could be optimized to only fetch the maximum cutoff time for
# a specific group, but seems unnecessary for now.
group_cutoff_time = _get_cutoff_time(
[e for sublist in group_rate_limits.values() for e in sublist]
)
user_group_ids = list(group_rate_limits.keys())
group_usage = _fetch_user_group_usage(
user_group_ids, group_cutoff_time, db_session
)
has_at_least_one_untriggered_limit = False
for user_group_id, rate_limits in group_rate_limits.items():
usage = group_usage.get(user_group_id, [])
if not _is_rate_limited(rate_limits, usage):
has_at_least_one_untriggered_limit = True
break
if not has_at_least_one_untriggered_limit:
raise HTTPException(
status_code=429,
detail="Token budget exceeded for user's groups. Try again later.",
)
def _fetch_all_user_group_rate_limits(
user_id: UUID, db_session: Session
) -> Dict[int, List[TokenRateLimit]]:
group_limits = (
select(TokenRateLimit, User__UserGroup.user_group_id)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.join(
UserGroup,
UserGroup.id == TokenRateLimit__UserGroup.user_group_id,
)
.join(
User__UserGroup,
User__UserGroup.user_group_id == UserGroup.id,
)
.where(
User__UserGroup.user_id == user_id,
TokenRateLimit.enabled.is_(True),
)
)
raw_rate_limits = db_session.execute(group_limits).all()
group_rate_limits = defaultdict(list)
for rate_limit, user_group_id in raw_rate_limits:
group_rate_limits[user_group_id].append(rate_limit)
return group_rate_limits
def _fetch_user_group_usage(
user_group_ids: list[int], cutoff_time: datetime, db_session: Session
) -> dict[int, list[Tuple[datetime, int]]]:
"""
Fetch user group usage within the cutoff time, grouped by minute
"""
user_group_usage = db_session.execute(
select(
func.sum(ChatMessage.token_count),
func.date_trunc("minute", ChatMessage.time_sent),
UserGroup.id,
)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.join(User__UserGroup, User__UserGroup.user_id == ChatSession.user_id)
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
.filter(UserGroup.id.in_(user_group_ids), ChatMessage.time_sent >= cutoff_time)
.group_by(func.date_trunc("minute", ChatMessage.time_sent), UserGroup.id)
).all()
return {
user_group_id: [(usage, time_sent) for time_sent, usage, _ in group_usage]
for user_group_id, group_usage in groupby(
user_group_usage, key=lambda row: row[2]
)
}

View File

@ -0,0 +1,105 @@
from collections import defaultdict
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
from ee.danswer.db.token_limit import insert_user_group_token_rate_limit
from ee.danswer.db.token_limit import insert_user_token_rate_limit
router = APIRouter(prefix="/admin/token-rate-limits")
"""
Group Token Limit Settings
"""
@router.get("/user-groups")
def get_all_group_token_limit_settings(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, list[TokenRateLimitDisplay]]:
user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group(
db_session
)
token_rate_limits_by_group = defaultdict(list)
for token_rate_limit, group_name in user_groups_to_token_rate_limits:
token_rate_limits_by_group[group_name].append(
TokenRateLimitDisplay.from_db(token_rate_limit)
)
return dict(token_rate_limits_by_group)
@router.get("/user-group/{group_id}")
def get_group_token_limit_settings(
group_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_all_user_group_token_rate_limits(
db_session, group_id
)
]
@router.post("/user-group/{group_id}")
def create_group_token_limit_settings(
group_id: int,
token_limit_settings: TokenRateLimitArgs,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
rate_limit_display = TokenRateLimitDisplay.from_db(
insert_user_group_token_rate_limit(
db_session=db_session,
token_rate_limit_settings=token_limit_settings,
group_id=group_id,
)
)
# clear cache in case this was the first rate limit created
any_rate_limit_exists.cache_clear()
return rate_limit_display
"""
User Token Limit Settings
"""
@router.get("/users")
def get_user_token_limit_settings(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_all_user_token_rate_limits(db_session)
]
@router.post("/users")
def create_user_token_limit_settings(
token_limit_settings: TokenRateLimitArgs,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
rate_limit_display = TokenRateLimitDisplay.from_db(
insert_user_token_rate_limit(db_session, token_limit_settings)
)
# clear cache in case this was the first rate limit created
any_rate_limit_exists.cache_clear()
return rate_limit_display

View File

@ -48,6 +48,11 @@ const nextConfig = {
source: "/admin/performance/custom-analytics",
destination: "/ee/admin/performance/custom-analytics",
},
// token rate limits
{
source: "/admin/token-rate-limits",
destination: "/ee/admin/token-rate-limits",
},
]
: [];

View File

@ -1,176 +1,9 @@
"use client";
import { Form, Formik } from "formik";
import { useEffect, useState } from "react";
import { AdminPageTitle } from "@/components/admin/Title";
import {
BooleanFormField,
SectionHeader,
TextFormField,
} from "@/components/admin/connectors/Field";
import { Popup } from "@/components/admin/connectors/Popup";
import { Button, Divider, Text } from "@tremor/react";
import { FiCpu } from "react-icons/fi";
import { LLMConfiguration } from "./LLMConfiguration";
const LLMOptions = () => {
const [popup, setPopup] = useState<{
message: string;
type: "success" | "error";
} | null>(null);
const [tokenBudgetGloballyEnabled, setTokenBudgetGloballyEnabled] =
useState(false);
const [initialValues, setInitialValues] = useState({
enable_token_budget: false,
token_budget: "",
token_budget_time_period: "",
});
const fetchConfig = async () => {
const response = await fetch("/api/manage/admin/token-budget-settings");
if (response.ok) {
const config = await response.json();
// Assuming the config object directly matches the structure needed for initialValues
setInitialValues({
enable_token_budget: config.enable_token_budget || false,
token_budget: config.token_budget || "",
token_budget_time_period: config.token_budget_time_period || "",
});
setTokenBudgetGloballyEnabled(true);
} else {
// Handle error or provide fallback values
setPopup({
message: "Failed to load current LLM options.",
type: "error",
});
}
};
// Fetch current config when the component mounts
useEffect(() => {
fetchConfig();
}, []);
if (!tokenBudgetGloballyEnabled) {
return null;
}
return (
<>
{popup && <Popup message={popup.message} type={popup.type} />}
<Formik
enableReinitialize={true}
initialValues={initialValues}
onSubmit={async (values) => {
const response = await fetch(
"/api/manage/admin/token-budget-settings",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(values),
}
);
if (response.ok) {
setPopup({
message: "Updated LLM Options",
type: "success",
});
await fetchConfig();
} else {
const body = await response.json();
if (body.detail) {
setPopup({ message: body.detail, type: "error" });
} else {
setPopup({
message: "Unable to update LLM options.",
type: "error",
});
}
setTimeout(() => {
setPopup(null);
}, 4000);
}
}}
>
{({ isSubmitting, values, setFieldValue }) => {
return (
<Form>
<Divider />
<>
<SectionHeader>Token Budget</SectionHeader>
<Text>
Set a maximum token use per time period. If the token budget
is exceeded, Danswer will not be able to respond to queries
until the next time period.
</Text>
<br />
<BooleanFormField
name="enable_token_budget"
label="Enable Token Budget"
subtext="If enabled, Danswer will be limited to the token budget specified below."
onChange={(e) => {
setFieldValue("enable_token_budget", e.target.checked);
}}
/>
{values.enable_token_budget && (
<>
<TextFormField
name="token_budget"
label="Token Budget"
subtext={
<div>
How many tokens (in thousands) can be used per time
period? If unspecified, no limit will be set.
</div>
}
onChange={(e) => {
const value = e.target.value;
// Allow only integer values
if (value === "" || /^[0-9]+$/.test(value)) {
setFieldValue("token_budget", value);
}
}}
/>
<TextFormField
name="token_budget_time_period"
label="Token Budget Time Period (hours)"
subtext={
<div>
Specify the length of the time period, in hours, over
which the token budget will be applied.
</div>
}
onChange={(e) => {
const value = e.target.value;
// Allow only integer values
if (value === "" || /^[0-9]+$/.test(value)) {
setFieldValue("token_budget_time_period", value);
}
}}
/>
</>
)}
</>
<div className="flex">
<Button
className="w-64 mx-auto"
type="submit"
disabled={isSubmitting}
>
Submit
</Button>
</div>
</Form>
);
}}
</Formik>
</>
);
};
const Page = () => {
return (
<div className="mx-auto container">
@ -181,7 +14,6 @@ const Page = () => {
<LLMConfiguration />
<LLMOptions />
</div>
);
};

View File

@ -0,0 +1,175 @@
"use client";
import * as Yup from "yup";
import { Button } from "@tremor/react";
import { useEffect, useState } from "react";
import { Modal } from "@/components/Modal";
import { Form, Formik } from "formik";
import {
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
import { UserGroup } from "@/lib/types";
import { Scope } from "./types";
import { PopupSpec } from "@/components/admin/connectors/Popup";
interface CreateRateLimitModalProps {
isOpen: boolean;
setIsOpen: (isOpen: boolean) => void;
onSubmit: (
target_scope: Scope,
period_hours: number,
token_budget: number,
group_id: number
) => void;
setPopup: (popupSpec: PopupSpec | null) => void;
forSpecificScope?: Scope;
forSpecificUserGroup?: number;
}
export const CreateRateLimitModal = ({
isOpen,
setIsOpen,
onSubmit,
setPopup,
forSpecificScope,
forSpecificUserGroup,
}: CreateRateLimitModalProps) => {
const [modalUserGroups, setModalUserGroups] = useState([]);
const [shouldFetchUserGroups, setShouldFetchUserGroups] = useState(
forSpecificScope === Scope.USER_GROUP
);
useEffect(() => {
const fetchData = async () => {
try {
const response = await fetch("/api/manage/admin/user-group");
const data = await response.json();
const options = data.map((userGroup: UserGroup) => ({
name: userGroup.name,
value: userGroup.id,
}));
setModalUserGroups(options);
setShouldFetchUserGroups(false);
} catch (error) {
setPopup({
type: "error",
message: `Failed to fetch user groups: ${error}`,
});
}
};
if (shouldFetchUserGroups) {
fetchData();
}
}, [shouldFetchUserGroups, setPopup]);
if (!isOpen) {
return null;
}
return (
<Modal
title={"Create a Token Rate Limit"}
onOutsideClick={() => setIsOpen(false)}
width="w-2/6"
>
<Formik
initialValues={{
enabled: true,
period_hours: "",
token_budget: "",
target_scope: forSpecificScope || Scope.GLOBAL,
user_group_id: forSpecificUserGroup,
}}
validationSchema={Yup.object().shape({
period_hours: Yup.number()
.required("Time Window is a required field")
.min(1, "Time Window must be at least 1 hour"),
token_budget: Yup.number()
.required("Token Budget is a required field")
.min(1, "Token Budget must be at least 1"),
target_scope: Yup.string().required(
"Target Scope is a required field"
),
user_group_id: Yup.string().test(
"user_group_id",
"User Group is a required field",
(value, context) => {
return (
context.parent.target_scope !== "user_group" ||
(context.parent.target_scope === "user_group" &&
value !== undefined)
);
}
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
onSubmit(
values.target_scope,
Number(values.period_hours),
Number(values.token_budget),
Number(values.user_group_id)
);
return formikHelpers.setSubmitting(false);
}}
>
{({ isSubmitting, values, setFieldValue }) => (
<Form>
{!forSpecificScope && (
<SelectorFormField
name="target_scope"
label="Target Scope"
options={[
{ name: "Global", value: Scope.GLOBAL },
{ name: "User", value: Scope.USER },
{ name: "User Group", value: Scope.USER_GROUP },
]}
includeDefault={false}
onSelect={(selected) => {
setFieldValue("target_scope", selected)
if (selected === Scope.USER_GROUP) {
setShouldFetchUserGroups(true);
}
}}
/>
)}
{forSpecificUserGroup === undefined &&
values.target_scope === Scope.USER_GROUP && (
<SelectorFormField
name="user_group_id"
label="User Group"
options={modalUserGroups}
includeDefault={false}
/>
)}
<TextFormField
name="period_hours"
label="Time Window (Hours)"
type="number"
placeholder=""
/>
<TextFormField
name="token_budget"
label="Token Budget (Thousands)"
type="number"
placeholder=""
/>
<div className="flex">
<Button
type="submit"
size="xs"
color="green"
disabled={isSubmitting}
className="mx-auto w-64"
>
Create!
</Button>
</div>
</Form>
)}
</Formik>
</Modal>
);
};

View File

@ -0,0 +1,169 @@
"use client";
import {
Table,
TableHead,
TableRow,
TableHeaderCell,
TableBody,
TableCell,
Title,
Text,
} from "@tremor/react";
import { DeleteButton } from "@/components/DeleteButton";
import { deleteTokenRateLimit, updateTokenRateLimit } from "./lib";
import { ThreeDotsLoader } from "@/components/Loading";
import { TokenRateLimitDisplay } from "./types";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { CustomCheckbox } from "@/components/CustomCheckbox";
type TokenRateLimitTableArgs = {
tokenRateLimits: TokenRateLimitDisplay[];
title?: string;
description?: string;
fetchUrl: string;
hideHeading?: boolean;
};
export const TokenRateLimitTable = ({
tokenRateLimits,
title,
description,
fetchUrl,
hideHeading,
}: TokenRateLimitTableArgs) => {
const shouldRenderGroupName = () =>
tokenRateLimits.length > 0 && tokenRateLimits[0].group_name !== undefined;
const handleEnabledChange = (id: number) => {
const tokenRateLimit = tokenRateLimits.find(
(tokenRateLimit) => tokenRateLimit.token_id === id
);
if (!tokenRateLimit) {
return;
}
updateTokenRateLimit(id, {
token_budget: tokenRateLimit.token_budget,
period_hours: tokenRateLimit.period_hours,
enabled: !tokenRateLimit.enabled,
}).then(() => {
mutate(fetchUrl);
});
};
const handleDelete = (id: number) =>
deleteTokenRateLimit(id).then(() => {
mutate(fetchUrl);
});
if (tokenRateLimits.length === 0) {
return (
<div>
{!hideHeading && title && <Title>{title}</Title>}
{!hideHeading && description && (
<Text className="my-2">{description}</Text>
)}
<Text className={`${!hideHeading && "my-8"}`}>
No token rate limits set!
</Text>
</div>
);
}
return (
<div>
{!hideHeading && title && <Title>{title}</Title>}
{!hideHeading && description && (
<Text className="my-2">{description}</Text>
)}
<Table className={`overflow-visible ${!hideHeading && "my-8"}`}>
<TableHead>
<TableRow>
<TableHeaderCell>Enabled</TableHeaderCell>
{shouldRenderGroupName() && (
<TableHeaderCell>Group Name</TableHeaderCell>
)}
<TableHeaderCell>Time Window (Hours)</TableHeaderCell>
<TableHeaderCell>Token Budget (Thousands)</TableHeaderCell>
<TableHeaderCell>Delete</TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{tokenRateLimits.map((tokenRateLimit) => {
return (
<TableRow key={tokenRateLimit.token_id}>
<TableCell>
<div
onClick={() => handleEnabledChange(tokenRateLimit.token_id)}
className="px-1 py-0.5 hover:bg-hover-light rounded flex cursor-pointer select-none w-24 flex"
>
<div className="mx-auto flex">
<CustomCheckbox checked={tokenRateLimit.enabled} />
<p className="ml-2">
{tokenRateLimit.enabled ? "Enabled" : "Disabled"}
</p>
</div>
</div>
</TableCell>
{shouldRenderGroupName() && (
<TableCell className="font-bold text-emphasis">
{tokenRateLimit.group_name}
</TableCell>
)}
<TableCell>{tokenRateLimit.period_hours}</TableCell>
<TableCell>{tokenRateLimit.token_budget}</TableCell>
<TableCell>
<DeleteButton
onClick={() => handleDelete(tokenRateLimit.token_id)}
/>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
};
export const GenericTokenRateLimitTable = ({
fetchUrl,
title,
description,
hideHeading,
responseMapper,
}: {
fetchUrl: string;
title?: string;
description?: string;
hideHeading?: boolean;
responseMapper?: (data: any) => TokenRateLimitDisplay[];
}) => {
const { data, isLoading, error } = useSWR(fetchUrl, errorHandlingFetcher);
if (isLoading) {
return <ThreeDotsLoader />;
}
if (!isLoading && error) {
return <Text>Failed to load token rate limits</Text>;
}
let processedData = data;
if (responseMapper) {
processedData = responseMapper(data);
}
return (
<TokenRateLimitTable
tokenRateLimits={processedData}
fetchUrl={fetchUrl}
title={title}
description={description}
hideHeading={hideHeading}
/>
);
};

View File

@ -0,0 +1,64 @@
import { TokenRateLimitArgs } from "./types";
const API_PREFIX = "/api/admin/token-rate-limits";
// Global Token Limits
export const insertGlobalTokenRateLimit = async (
tokenRateLimit: TokenRateLimitArgs
) => {
return await fetch(`${API_PREFIX}/global`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(tokenRateLimit),
});
};
// User Token Limits
export const insertUserTokenRateLimit = async (
tokenRateLimit: TokenRateLimitArgs
) => {
return await fetch(`${API_PREFIX}/users`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(tokenRateLimit),
});
};
// User Group Token Limits (EE Only)
export const insertGroupTokenRateLimit = async (
tokenRateLimit: TokenRateLimitArgs,
group_id: number
) => {
return await fetch(`${API_PREFIX}/user-group/${group_id}`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(tokenRateLimit),
});
};
// Common Endpoints
export const deleteTokenRateLimit = async (token_rate_limit_id: number) => {
return await fetch(`${API_PREFIX}/rate-limit/${token_rate_limit_id}`, {
method: "DELETE",
});
};
export const updateTokenRateLimit = async (
token_rate_limit_id: number,
tokenRateLimit: TokenRateLimitArgs
) => {
return await fetch(`${API_PREFIX}/rate-limit/${token_rate_limit_id}`, {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(tokenRateLimit),
});
};

View File

@ -0,0 +1,223 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import {
Button,
Tab,
TabGroup,
TabList,
TabPanel,
TabPanels,
Text,
} from "@tremor/react";
import { useState } from "react";
import { FiGlobe, FiShield, FiUser, FiUsers } from "react-icons/fi";
import {
insertGlobalTokenRateLimit,
insertGroupTokenRateLimit,
insertUserTokenRateLimit,
} from "./lib";
import { Scope, TokenRateLimit } from "./types";
import { GenericTokenRateLimitTable } from "./TokenRateLimitTables";
import { mutate } from "swr";
import { usePopup } from "@/components/admin/connectors/Popup";
import { CreateRateLimitModal } from "./CreateRateLimitModal";
import { EE_ENABLED } from "@/lib/constants";
const BASE_URL = "/api/admin/token-rate-limits";
const GLOBAL_TOKEN_FETCH_URL = `${BASE_URL}/global`;
const USER_TOKEN_FETCH_URL = `${BASE_URL}/users`;
const USER_GROUP_FETCH_URL = `${BASE_URL}/user-groups`;
const GLOBAL_DESCRIPTION =
"Global rate limits apply to all users, user groups, and API keys. When the global \
rate limit is reached, no more tokens can be spent.";
const USER_DESCRIPTION =
"User rate limits apply to individual users. When a user reaches a limit, they will \
be temporarily blocked from spending tokens.";
const USER_GROUP_DESCRIPTION =
"User group rate limits apply to all users in a group. When a group reaches a limit, \
all users in the group will be temporarily blocked from spending tokens, regardless \
of their individual limits. If a user is in multiple groups, the most lenient limit \
will apply.";
const handleCreateTokenRateLimit = async (
target_scope: Scope,
period_hours: number,
token_budget: number,
group_id: number = -1
) => {
const tokenRateLimitArgs = {
enabled: true,
token_budget: token_budget,
period_hours: period_hours,
};
if (target_scope === Scope.GLOBAL) {
return await insertGlobalTokenRateLimit(tokenRateLimitArgs);
} else if (target_scope === Scope.USER) {
return await insertUserTokenRateLimit(tokenRateLimitArgs);
} else if (target_scope === Scope.USER_GROUP) {
return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id);
} else {
throw new Error(`Invalid target_scope: ${target_scope}`);
}
};
function Main() {
const [tabIndex, setTabIndex] = useState(0);
const [modalIsOpen, setModalIsOpen] = useState(false);
const { popup, setPopup } = usePopup();
const updateTable = (target_scope: Scope) => {
if (target_scope === Scope.GLOBAL) {
mutate(GLOBAL_TOKEN_FETCH_URL);
setTabIndex(0);
} else if (target_scope === Scope.USER) {
mutate(USER_TOKEN_FETCH_URL);
setTabIndex(1);
} else if (target_scope === Scope.USER_GROUP) {
mutate(USER_GROUP_FETCH_URL);
setTabIndex(2);
}
};
const handleSubmit = (
target_scope: Scope,
period_hours: number,
token_budget: number,
group_id: number = -1
) => {
handleCreateTokenRateLimit(
target_scope,
period_hours,
token_budget,
group_id
)
.then(() => {
setModalIsOpen(false);
setPopup({ type: "success", message: "Token rate limit created!" });
updateTable(target_scope);
})
.catch((error) => {
setPopup({ type: "error", message: error.message });
});
};
return (
<div>
{popup}
<Text className="mb-2">
Token rate limits enable you control how many tokens can be spent in a
given time period. With token rate limits, you can:
</Text>
<ul className="list-disc mt-2 ml-4 mb-2">
<li>
<Text>
Set a global rate limit to control your organization&apos;s overall
token spend.
</Text>
</li>
{EE_ENABLED && (
<>
<li>
<Text>
Set rate limits for users to ensure that no single user can
spend too many tokens.
</Text>
</li>
<li>
<Text>
Set rate limits for user groups to control token spend for your
teams.
</Text>
</li>
</>
)}
<li>
<Text>Enable and disable rate limits on the fly.</Text>
</li>
</ul>
<Button
color="green"
size="xs"
className="mt-3"
onClick={() => setModalIsOpen(true)}
>
Create a Token Rate Limit
</Button>
{EE_ENABLED && (
<TabGroup className="mt-6" index={tabIndex} onIndexChange={setTabIndex}>
<TabList variant="line">
<Tab icon={FiGlobe}>Global</Tab>
<Tab icon={FiUser}>User</Tab>
<Tab icon={FiUsers}>User Groups</Tab>
</TabList>
<TabPanels className="mt-6">
<TabPanel>
<GenericTokenRateLimitTable
fetchUrl={GLOBAL_TOKEN_FETCH_URL}
title={"Global Token Rate Limits"}
description={GLOBAL_DESCRIPTION}
/>
</TabPanel>
<TabPanel>
<GenericTokenRateLimitTable
fetchUrl={USER_TOKEN_FETCH_URL}
title={"User Token Rate Limits"}
description={USER_DESCRIPTION}
/>
</TabPanel>
<TabPanel>
<GenericTokenRateLimitTable
fetchUrl={USER_GROUP_FETCH_URL}
title={"User Group Token Rate Limits"}
description={USER_GROUP_DESCRIPTION}
responseMapper={(data: Record<string, TokenRateLimit[]>) =>
Object.entries(data).flatMap(([group_name, elements]) =>
elements.map((element) => ({
...element,
group_name,
}))
)
}
/>
</TabPanel>
</TabPanels>
</TabGroup>
)}
{!EE_ENABLED && (
<div className="mt-6">
<GenericTokenRateLimitTable
fetchUrl={GLOBAL_TOKEN_FETCH_URL}
title={"Global Token Rate Limits"}
description={GLOBAL_DESCRIPTION}
/>
</div>
)}
<CreateRateLimitModal
isOpen={modalIsOpen}
setIsOpen={() => setModalIsOpen(false)}
setPopup={setPopup}
onSubmit={handleSubmit}
forSpecificScope={EE_ENABLED ? undefined : Scope.GLOBAL}
/>
</div>
);
}
export default function Page() {
return (
<div className="mx-auto container">
<AdminPageTitle title="Token Rate Limits" icon={<FiShield size={32} />} />
<Main />
</div>
);
}

View File

@ -0,0 +1,22 @@
export enum Scope {
USER = "user",
USER_GROUP = "user_group",
GLOBAL = "global",
}
export interface TokenRateLimitArgs {
enabled: boolean;
token_budget: number;
period_hours: number;
}
export interface TokenRateLimit {
token_id: number;
enabled: boolean;
token_budget: number;
period_hours: number;
}
export interface TokenRateLimitDisplay extends TokenRateLimit {
group_name?: string;
}

View File

@ -0,0 +1,60 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { CreateRateLimitModal } from "../../../../admin/token-rate-limits/CreateRateLimitModal";
import { Scope } from "../../../../admin/token-rate-limits/types";
import { insertGroupTokenRateLimit } from "../../../../admin/token-rate-limits/lib";
import { mutate } from "swr";
interface AddMemberFormProps {
isOpen: boolean;
setIsOpen: (isOpen: boolean) => void;
setPopup: (popupSpec: PopupSpec | null) => void;
userGroupId: number;
}
const handleCreateGroupTokenRateLimit = async (
period_hours: number,
token_budget: number,
group_id: number = -1
) => {
const tokenRateLimitArgs = {
enabled: true,
token_budget: token_budget,
period_hours: period_hours,
};
return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id);
};
export const AddTokenRateLimitForm: React.FC<AddMemberFormProps> = ({
isOpen,
setIsOpen,
setPopup,
userGroupId,
}) => {
const handleSubmit = (
_: Scope,
period_hours: number,
token_budget: number,
group_id: number = -1
) => {
handleCreateGroupTokenRateLimit(period_hours, token_budget, group_id)
.then(() => {
setIsOpen(false);
setPopup({ type: "success", message: "Token rate limit created!" });
mutate(`/api/admin/token-rate-limits/user-group/${userGroupId}`);
})
.catch((error) => {
setPopup({ type: "error", message: error.message });
});
};
return (
<CreateRateLimitModal
isOpen={isOpen}
setIsOpen={setIsOpen}
onSubmit={handleSubmit}
setPopup={setPopup}
forSpecificScope={Scope.USER_GROUP}
forSpecificUserGroup={userGroupId}
/>
);
};

View File

@ -22,6 +22,8 @@ import {
import { DeleteButton } from "@/components/DeleteButton";
import { Bubble } from "@/components/Bubble";
import { BookmarkIcon, RobotIcon } from "@/components/icons/icons";
import { AddTokenRateLimitForm } from "./AddTokenRateLimitForm";
import { GenericTokenRateLimitTable } from "@/app/admin/token-rate-limits/TokenRateLimitTables";
interface GroupDisplayProps {
users: User[];
@ -39,6 +41,7 @@ export const GroupDisplay = ({
const { popup, setPopup } = usePopup();
const [addMemberFormVisible, setAddMemberFormVisible] = useState(false);
const [addConnectorFormVisible, setAddConnectorFormVisible] = useState(false);
const [addRateLimitFormVisible, setAddRateLimitFormVisible] = useState(false);
return (
<div>
@ -301,6 +304,31 @@ export const GroupDisplay = ({
</>
)}
</div>
<Divider />
<h2 className="text-xl font-bold mt-8 mb-2">Token Rate Limits</h2>
<AddTokenRateLimitForm
isOpen={addRateLimitFormVisible}
setIsOpen={setAddRateLimitFormVisible}
setPopup={setPopup}
userGroupId={userGroup.id}
/>
<GenericTokenRateLimitTable
fetchUrl={`/api/admin/token-rate-limits/user-group/${userGroup.id}`}
hideHeading
/>
<Button
color="green"
size="xs"
className="mt-3"
onClick={() => setAddRateLimitFormVisible(true)}
>
Create a Token Rate Limit
</Button>
</div>
);
};

View File

@ -28,6 +28,7 @@ import {
FiImage,
FiPackage,
FiSettings,
FiShield,
FiSlack,
FiTool,
} from "react-icons/fi";
@ -215,6 +216,15 @@ export async function Layout({ children }: { children: React.ReactNode }) {
},
]
: []),
{
name: (
<div className="flex">
<FiShield size={18} />
<div className="ml-1">Token Rate Limits</div>
</div>
),
link: "/admin/token-rate-limits",
},
],
},
...(EE_ENABLED