From d7a704c0d92c75dc114be30edfc4ef7790b6cdcd Mon Sep 17 00:00:00 2001 From: Alan Hagedorn Date: Sun, 14 Apr 2024 18:53:38 -0700 Subject: [PATCH] Token Rate Limiting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WIP Cleanup 🧹 Remove existing rate limiting logic Cleanup 🧼 Undo nit Cleanup 🧽 Move db constants (avoids circular import) WIP WIP Cleanup Lint Resolve alembic conflict Fix mypy Add backfill to migration Update comment Make unauthenticated users still adhere to global limits Use Depends Remove enum from table Update migration error handling + deletion Address verbal feedback, cleanup urls, minor nits --- backend/danswer/configs/constants.py | 4 - backend/danswer/main.py | 6 + .../danswer/server/manage/administrative.py | 45 ---- .../server/query_and_chat/chat_backend.py | 2 + .../server/query_and_chat/query_backend.py | 4 +- .../server/query_and_chat/token_budget.py | 79 ------- .../server/query_and_chat/token_limit.py | 135 +++++++++++ .../danswer/server/token_rate_limits/api.py | 79 +++++++ .../server/token_rate_limits/models.py | 25 ++ backend/ee/danswer/db/token_limit.py | 176 ++++++++++++++ backend/ee/danswer/db/user_group.py | 20 ++ backend/ee/danswer/main.py | 7 + .../server/query_and_chat/token_limit.py | 184 +++++++++++++++ .../danswer/server/token_rate_limits/api.py | 105 +++++++++ web/next.config.js | 5 + web/src/app/admin/models/llm/page.tsx | 168 ------------- .../CreateRateLimitModal.tsx | 175 ++++++++++++++ .../TokenRateLimitTables.tsx | 169 +++++++++++++ web/src/app/admin/token-rate-limits/lib.ts | 64 +++++ web/src/app/admin/token-rate-limits/page.tsx | 223 ++++++++++++++++++ web/src/app/admin/token-rate-limits/types.ts | 22 ++ .../[groupId]/AddTokenRateLimitForm.tsx | 60 +++++ .../admin/groups/[groupId]/GroupDisplay.tsx | 28 +++ web/src/components/admin/Layout.tsx | 10 + 24 files changed, 1497 insertions(+), 298 deletions(-) delete mode 100644 backend/danswer/server/query_and_chat/token_budget.py create mode 100644 backend/danswer/server/query_and_chat/token_limit.py create mode 100644 backend/danswer/server/token_rate_limits/api.py create mode 100644 backend/danswer/server/token_rate_limits/models.py create mode 100644 backend/ee/danswer/db/token_limit.py create mode 100644 backend/ee/danswer/server/query_and_chat/token_limit.py create mode 100644 backend/ee/danswer/server/token_rate_limits/api.py create mode 100644 web/src/app/admin/token-rate-limits/CreateRateLimitModal.tsx create mode 100644 web/src/app/admin/token-rate-limits/TokenRateLimitTables.tsx create mode 100644 web/src/app/admin/token-rate-limits/lib.ts create mode 100644 web/src/app/admin/token-rate-limits/page.tsx create mode 100644 web/src/app/admin/token-rate-limits/types.ts create mode 100644 web/src/app/ee/admin/groups/[groupId]/AddTokenRateLimitForm.tsx diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 9c707d43a0..1eab4b124d 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -41,10 +41,6 @@ DEFAULT_BOOST = 0 SESSION_KEY = "session" QUERY_EVENT_ID = "query_event_id" LLM_CHUNKS = "llm_chunks" -TOKEN_BUDGET = "token_budget" -TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period" -ENABLE_TOKEN_BUDGET = "enable_token_budget" -TOKEN_BUDGET_SETTINGS = "token_budget_settings" # For chunking/processing chunks TITLE_SEPARATOR = "\n\r\n" diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 98cc6e2010..6321894e40 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -83,6 +83,9 @@ from danswer.server.settings.api import basic_router as settings_router from danswer.tools.built_in_tools import auto_add_search_tool_to_personas from danswer.tools.built_in_tools import load_builtin_tools from danswer.tools.built_in_tools import refresh_built_in_tools_cache +from danswer.server.token_rate_limits.api import ( + router as token_rate_limit_settings_router, +) from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType @@ -281,6 +284,9 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, settings_admin_router) include_router_with_global_prefix_prepended(application, llm_admin_router) include_router_with_global_prefix_prepended(application, llm_router) + include_router_with_global_prefix_prepended( + application, token_rate_limit_settings_router + ) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index c60206ca3f..f8d4cd8aa3 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -1,23 +1,16 @@ -import json from datetime import datetime from datetime import timedelta from datetime import timezone from typing import cast from fastapi import APIRouter -from fastapi import Body from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ -from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED from danswer.configs.constants import DocumentSource -from danswer.configs.constants import ENABLE_TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET_SETTINGS -from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.engine import get_session @@ -193,41 +186,3 @@ def create_deletion_attempt_for_connector_id( file_store = get_default_file_store(db_session) for file_name in connector.connector_specific_config["file_locations"]: file_store.delete_file(file_name) - - -@router.get("/admin/token-budget-settings") -def get_token_budget_settings(_: User = Depends(current_admin_user)) -> dict: - if not TOKEN_BUDGET_GLOBALLY_ENABLED: - raise HTTPException( - status_code=400, detail="Token budget is not enabled in the application." - ) - - try: - settings_json = cast( - str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS) - ) - settings = json.loads(settings_json) - return settings - except ConfigNotFoundError: - raise HTTPException(status_code=404, detail="Token budget settings not found.") - - -@router.put("/admin/token-budget-settings") -def update_token_budget_settings( - _: User = Depends(current_admin_user), - enable_token_budget: bool = Body(..., embed=True), - token_budget: int = Body(..., ge=0, embed=True), # Ensure non-negative - token_budget_time_period: int = Body(..., ge=1, embed=True), # Ensure positive -) -> dict[str, str]: - # Prepare the settings as a JSON string - settings_json = json.dumps( - { - ENABLE_TOKEN_BUDGET: enable_token_budget, - TOKEN_BUDGET: token_budget, - TOKEN_BUDGET_TIME_PERIOD: token_budget_time_period, - } - ) - - # Store the settings in the dynamic config store - get_dynamic_config_store().store(TOKEN_BUDGET_SETTINGS, settings_json) - return {"message": "Token budget settings updated successfully."} diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 090c920128..57c63b14d9 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -64,6 +64,7 @@ from danswer.server.query_and_chat.models import PromptOverride from danswer.server.query_and_chat.models import RenameChatSessionResponse from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest +from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger logger = setup_logger() @@ -275,6 +276,7 @@ def handle_new_chat_message( chat_message_req: CreateChatMessageRequest, request: Request, user: User | None = Depends(current_user), + _: None = Depends(check_token_rate_limits), ) -> StreamingResponse: """This endpoint is both used for all the following purposes: - Sending a new message in the session diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index a151c7d30d..43192211b7 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -29,7 +29,7 @@ from danswer.server.query_and_chat.models import QueryValidationResponse from danswer.server.query_and_chat.models import SimpleQueryRequest from danswer.server.query_and_chat.models import SourceTag from danswer.server.query_and_chat.models import TagResponse -from danswer.server.query_and_chat.token_budget import check_token_budget +from danswer.server.query_and_chat.token_limit import check_token_rate_limits from danswer.utils.logger import setup_logger logger = setup_logger() @@ -150,7 +150,7 @@ def stream_query_validation( def get_answer_with_quote( query_request: DirectQARequest, user: User = Depends(current_user), - _: bool = Depends(check_token_budget), + _: None = Depends(check_token_rate_limits), ) -> StreamingResponse: query = query_request.messages[0].message logger.info(f"Received query for one shot answer with quotes: {query}") diff --git a/backend/danswer/server/query_and_chat/token_budget.py b/backend/danswer/server/query_and_chat/token_budget.py deleted file mode 100644 index 49a84f5b0e..0000000000 --- a/backend/danswer/server/query_and_chat/token_budget.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -from datetime import datetime -from datetime import timedelta -from typing import cast - -from fastapi import HTTPException -from sqlalchemy import func -from sqlalchemy.orm import Session - -from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED -from danswer.configs.constants import ENABLE_TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET -from danswer.configs.constants import TOKEN_BUDGET_SETTINGS -from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD -from danswer.db.engine import get_session_context_manager -from danswer.db.models import ChatMessage -from danswer.dynamic_configs.factory import get_dynamic_config_store - -BUDGET_LIMIT_DEFAULT = -1 # Default to no limit -TIME_PERIOD_HOURS_DEFAULT = 12 - - -def is_under_token_budget(db_session: Session) -> bool: - try: - settings_json = cast( - str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS) - ) - except Exception: - return True - - settings = json.loads(settings_json) - - is_enabled = settings.get(ENABLE_TOKEN_BUDGET, False) - - if not is_enabled: - return True - - budget_limit = settings.get(TOKEN_BUDGET, -1) - - if budget_limit < 0: - return True - - period_hours = settings.get(TOKEN_BUDGET_TIME_PERIOD, TIME_PERIOD_HOURS_DEFAULT) - period_start_time = datetime.now() - timedelta(hours=period_hours) - - # Fetch the sum of all tokens used within the period - token_sum = ( - db_session.query(func.sum(ChatMessage.token_count)) - .filter(ChatMessage.time_sent >= period_start_time) - .scalar() - or 0 - ) - - print( - "token_sum:", - token_sum, - "budget_limit:", - budget_limit, - "period_hours:", - period_hours, - "period_start_time:", - period_start_time, - ) - - return token_sum < ( - budget_limit * 1000 - ) # Budget limit is expressed in thousands of tokens - - -def check_token_budget() -> None: - if not TOKEN_BUDGET_GLOBALLY_ENABLED: - return None - - with get_session_context_manager() as db_session: - # Perform the token budget check here, possibly using `user` and `db_session` for database access if needed - if not is_under_token_budget(db_session): - raise HTTPException( - status_code=429, detail="Sorry, token budget exceeded. Try again later." - ) diff --git a/backend/danswer/server/query_and_chat/token_limit.py b/backend/danswer/server/query_and_chat/token_limit.py new file mode 100644 index 0000000000..b44eec3a64 --- /dev/null +++ b/backend/danswer/server/query_and_chat/token_limit.py @@ -0,0 +1,135 @@ +from collections.abc import Sequence +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from functools import lru_cache + +from dateutil import tz +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.db.engine import get_session_context_manager +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession +from danswer.db.models import TokenRateLimit +from danswer.db.models import User +from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import fetch_versioned_implementation +from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits + + +logger = setup_logger() + + +TOKEN_BUDGET_UNIT = 1_000 + + +def check_token_rate_limits( + user: User | None = Depends(current_user), +) -> None: + # short circuit if no rate limits are set up + # NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time + if not any_rate_limit_exists(): + return + + versioned_rate_limit_strategy = fetch_versioned_implementation( + "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits" + ) + return versioned_rate_limit_strategy(user) + + +def _check_token_rate_limits(_: User | None) -> None: + _user_is_rate_limited_by_global() + + +""" +Global rate limits +""" + + +def _user_is_rate_limited_by_global() -> None: + with get_session_context_manager() as db_session: + global_rate_limits = fetch_all_global_token_rate_limits( + db_session=db_session, enabled_only=True, ordered=False + ) + + if global_rate_limits: + global_cutoff_time = _get_cutoff_time(global_rate_limits) + global_usage = _fetch_global_usage(global_cutoff_time, db_session) + + if _is_rate_limited(global_rate_limits, global_usage): + raise HTTPException( + status_code=429, + detail="Token budget exceeded for organization. Try again later.", + ) + + +def _fetch_global_usage( + cutoff_time: datetime, db_session: Session +) -> Sequence[tuple[datetime, int]]: + """ + Fetch global token usage within the cutoff time, grouped by minute + """ + result = db_session.execute( + select( + func.date_trunc("minute", ChatMessage.time_sent), + func.sum(ChatMessage.token_count), + ) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .filter( + ChatMessage.time_sent >= cutoff_time, + ) + .group_by(func.date_trunc("minute", ChatMessage.time_sent)) + ).all() + + return [(row[0], row[1]) for row in result] + + +""" +Common functions +""" + + +def _get_cutoff_time(rate_limits: Sequence[TokenRateLimit]) -> datetime: + max_period_hours = max(rate_limit.period_hours for rate_limit in rate_limits) + return datetime.now(tz=timezone.utc) - timedelta(hours=max_period_hours) + + +def _is_rate_limited( + rate_limits: Sequence[TokenRateLimit], usage: Sequence[tuple[datetime, int]] +) -> bool: + """ + If at least one rate limit is exceeded, return True + """ + for rate_limit in rate_limits: + tokens_used = sum( + u_token_count + for u_date, u_token_count in usage + if u_date + >= datetime.now(tz=tz.UTC) - timedelta(hours=rate_limit.period_hours) + ) + + if tokens_used >= rate_limit.token_budget * TOKEN_BUDGET_UNIT: + return True + + return False + + +@lru_cache() +def any_rate_limit_exists() -> bool: + """Checks if any rate limit exists in the database. Is cached, so that if no rate limits + are setup, we don't have any effect on average query latency.""" + logger.info("Checking for any rate limits...") + with get_session_context_manager() as db_session: + return ( + db_session.scalar( + select(TokenRateLimit.id).where( + TokenRateLimit.enabled == True # noqa: E712 + ) + ) + is not None + ) diff --git a/backend/danswer/server/token_rate_limits/api.py b/backend/danswer/server/token_rate_limits/api.py new file mode 100644 index 0000000000..245e339141 --- /dev/null +++ b/backend/danswer/server/token_rate_limits/api.py @@ -0,0 +1,79 @@ +from fastapi import APIRouter +from fastapi import Depends +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.server.query_and_chat.token_limit import any_rate_limit_exists +from danswer.server.token_rate_limits.models import TokenRateLimitArgs +from danswer.server.token_rate_limits.models import TokenRateLimitDisplay +from ee.danswer.db.token_limit import delete_token_rate_limit +from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits +from ee.danswer.db.token_limit import insert_global_token_rate_limit +from ee.danswer.db.token_limit import update_token_rate_limit + +router = APIRouter(prefix="/admin/token-rate-limits") + + +""" +Global Token Limit Settings +""" + + +@router.get("/global") +def get_global_token_limit_settings( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[TokenRateLimitDisplay]: + return [ + TokenRateLimitDisplay.from_db(token_rate_limit) + for token_rate_limit in fetch_all_global_token_rate_limits(db_session) + ] + + +@router.post("/global") +def create_global_token_limit_settings( + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + rate_limit_display = TokenRateLimitDisplay.from_db( + insert_global_token_rate_limit(db_session, token_limit_settings) + ) + # clear cache in case this was the first rate limit created + any_rate_limit_exists.cache_clear() + return rate_limit_display + + +""" +General Token Limit Settings +""" + + +@router.put("/rate-limit/{token_rate_limit_id}") +def update_token_limit_settings( + token_rate_limit_id: int, + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + return TokenRateLimitDisplay.from_db( + update_token_rate_limit( + db_session=db_session, + token_rate_limit_id=token_rate_limit_id, + token_rate_limit_settings=token_limit_settings, + ) + ) + + +@router.delete("/rate-limit/{token_rate_limit_id}") +def delete_token_limit_settings( + token_rate_limit_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + return delete_token_rate_limit( + db_session=db_session, + token_rate_limit_id=token_rate_limit_id, + ) diff --git a/backend/danswer/server/token_rate_limits/models.py b/backend/danswer/server/token_rate_limits/models.py new file mode 100644 index 0000000000..351abe92e7 --- /dev/null +++ b/backend/danswer/server/token_rate_limits/models.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel + +from danswer.db.models import TokenRateLimit + + +class TokenRateLimitArgs(BaseModel): + enabled: bool + token_budget: int + period_hours: int + + +class TokenRateLimitDisplay(BaseModel): + token_id: int + enabled: bool + token_budget: int + period_hours: int + + @classmethod + def from_db(cls, token_rate_limit: TokenRateLimit) -> "TokenRateLimitDisplay": + return cls( + token_id=token_rate_limit.id, + enabled=token_rate_limit.enabled, + token_budget=token_rate_limit.token_budget, + period_hours=token_rate_limit.period_hours, + ) diff --git a/backend/ee/danswer/db/token_limit.py b/backend/ee/danswer/db/token_limit.py new file mode 100644 index 0000000000..9b15381163 --- /dev/null +++ b/backend/ee/danswer/db/token_limit.py @@ -0,0 +1,176 @@ +from collections.abc import Sequence + +from sqlalchemy import Row +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.configs.constants import TokenRateLimitScope +from danswer.db.models import TokenRateLimit +from danswer.db.models import TokenRateLimit__UserGroup +from danswer.db.models import UserGroup +from danswer.server.token_rate_limits.models import TokenRateLimitArgs + + +def fetch_all_user_token_rate_limits( + db_session: Session, + enabled_only: bool = False, + ordered: bool = True, +) -> Sequence[TokenRateLimit]: + query = select(TokenRateLimit).where( + TokenRateLimit.scope == TokenRateLimitScope.USER + ) + + if enabled_only: + query = query.where(TokenRateLimit.enabled.is_(True)) + + if ordered: + query = query.order_by(TokenRateLimit.created_at.desc()) + + return db_session.scalars(query).all() + + +def fetch_all_global_token_rate_limits( + db_session: Session, + enabled_only: bool = False, + ordered: bool = True, +) -> Sequence[TokenRateLimit]: + query = select(TokenRateLimit).where( + TokenRateLimit.scope == TokenRateLimitScope.GLOBAL + ) + + if enabled_only: + query = query.where(TokenRateLimit.enabled.is_(True)) + + if ordered: + query = query.order_by(TokenRateLimit.created_at.desc()) + + token_rate_limits = db_session.scalars(query).all() + return token_rate_limits + + +def fetch_all_user_group_token_rate_limits( + db_session: Session, group_id: int, enabled_only: bool = False, ordered: bool = True +) -> Sequence[TokenRateLimit]: + query = ( + select(TokenRateLimit) + .join( + TokenRateLimit__UserGroup, + TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, + ) + .where( + TokenRateLimit__UserGroup.user_group_id == group_id, + TokenRateLimit.scope == TokenRateLimitScope.USER_GROUP, + ) + ) + + if enabled_only: + query = query.where(TokenRateLimit.enabled.is_(True)) + + if ordered: + query = query.order_by(TokenRateLimit.created_at.desc()) + + token_rate_limits = db_session.scalars(query).all() + return token_rate_limits + + +def fetch_all_user_group_token_rate_limits_by_group( + db_session: Session, +) -> Sequence[Row[tuple[TokenRateLimit, str]]]: + query = ( + select(TokenRateLimit, UserGroup.name) + .join( + TokenRateLimit__UserGroup, + TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, + ) + .join(UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id) + ) + + return db_session.execute(query).all() + + +def insert_user_token_rate_limit( + db_session: Session, + token_rate_limit_settings: TokenRateLimitArgs, +) -> TokenRateLimit: + token_limit = TokenRateLimit( + enabled=token_rate_limit_settings.enabled, + token_budget=token_rate_limit_settings.token_budget, + period_hours=token_rate_limit_settings.period_hours, + scope=TokenRateLimitScope.USER, + ) + db_session.add(token_limit) + db_session.commit() + + return token_limit + + +def insert_global_token_rate_limit( + db_session: Session, + token_rate_limit_settings: TokenRateLimitArgs, +) -> TokenRateLimit: + token_limit = TokenRateLimit( + enabled=token_rate_limit_settings.enabled, + token_budget=token_rate_limit_settings.token_budget, + period_hours=token_rate_limit_settings.period_hours, + scope=TokenRateLimitScope.GLOBAL, + ) + db_session.add(token_limit) + db_session.commit() + + return token_limit + + +def insert_user_group_token_rate_limit( + db_session: Session, + token_rate_limit_settings: TokenRateLimitArgs, + group_id: int, +) -> TokenRateLimit: + token_limit = TokenRateLimit( + enabled=token_rate_limit_settings.enabled, + token_budget=token_rate_limit_settings.token_budget, + period_hours=token_rate_limit_settings.period_hours, + scope=TokenRateLimitScope.USER_GROUP, + ) + db_session.add(token_limit) + db_session.flush() + + rate_limit = TokenRateLimit__UserGroup( + rate_limit_id=token_limit.id, user_group_id=group_id + ) + db_session.add(rate_limit) + db_session.commit() + + return token_limit + + +def update_token_rate_limit( + db_session: Session, + token_rate_limit_id: int, + token_rate_limit_settings: TokenRateLimitArgs, +) -> TokenRateLimit: + token_limit = db_session.get(TokenRateLimit, token_rate_limit_id) + if token_limit is None: + raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found") + + token_limit.enabled = token_rate_limit_settings.enabled + token_limit.token_budget = token_rate_limit_settings.token_budget + token_limit.period_hours = token_rate_limit_settings.period_hours + db_session.commit() + + return token_limit + + +def delete_token_rate_limit( + db_session: Session, + token_rate_limit_id: int, +) -> None: + token_limit = db_session.get(TokenRateLimit, token_rate_limit_id) + if token_limit is None: + raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found") + + db_session.query(TokenRateLimit__UserGroup).filter( + TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id + ).delete() + + db_session.delete(token_limit) + db_session.commit() diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index dba2ea5c80..beff6c58c6 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -9,6 +9,7 @@ from sqlalchemy.orm import Session from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Document from danswer.db.models import DocumentByConnectorCredentialPair +from danswer.db.models import TokenRateLimit__UserGroup from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.db.models import UserGroup @@ -241,6 +242,21 @@ def update_user_group( return db_user_group +def _cleanup_token_rate_limit__user_group_relationships__no_commit( + db_session: Session, user_group_id: int +) -> None: + """NOTE: does not commit the transaction.""" + token_rate_limit__user_group_relationships = db_session.scalars( + select(TokenRateLimit__UserGroup).where( + TokenRateLimit__UserGroup.user_group_id == user_group_id + ) + ).all() + for ( + token_rate_limit__user_group_relationship + ) in token_rate_limit__user_group_relationships: + db_session.delete(token_rate_limit__user_group_relationship) + + def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) @@ -255,6 +271,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session=db_session, user_group_id=user_group_id ) + _cleanup_token_rate_limit__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) + db_user_group.is_up_to_date = False db_user_group.is_up_for_deletion = True db_session.commit() diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 417c5afd8e..a8d458428b 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -34,6 +34,9 @@ from ee.danswer.server.query_and_chat.query_backend import ( ) from ee.danswer.server.query_history.api import router as query_history_router from ee.danswer.server.saml import router as saml_router +from ee.danswer.server.token_rate_limits.api import ( + router as token_rate_limit_settings_router, +) from ee.danswer.server.user_group.api import router as user_group_router logger = setup_logger() @@ -85,6 +88,10 @@ def get_ee_application() -> FastAPI: include_router_with_global_prefix_prepended( application, enterprise_settings_admin_router ) + # Token rate limit settings + include_router_with_global_prefix_prepended( + application, token_rate_limit_settings_router + ) include_router_with_global_prefix_prepended(application, enterprise_settings_router) # Ensure all routes have auth enabled or are explicitly marked as public diff --git a/backend/ee/danswer/server/query_and_chat/token_limit.py b/backend/ee/danswer/server/query_and_chat/token_limit.py new file mode 100644 index 0000000000..538458fb63 --- /dev/null +++ b/backend/ee/danswer/server/query_and_chat/token_limit.py @@ -0,0 +1,184 @@ +from collections import defaultdict +from collections.abc import Sequence +from datetime import datetime +from itertools import groupby +from typing import Dict +from typing import List +from typing import Tuple +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.engine import get_session_context_manager +from danswer.db.models import ChatMessage +from danswer.db.models import ChatSession +from danswer.db.models import TokenRateLimit +from danswer.db.models import TokenRateLimit__UserGroup +from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserGroup +from danswer.server.query_and_chat.token_limit import _get_cutoff_time +from danswer.server.query_and_chat.token_limit import _is_rate_limited +from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from ee.danswer.db.api_key import is_api_key_email_address +from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits + + +def _check_token_rate_limits(user: User | None) -> None: + if user is None: + # Unauthenticated users are only rate limited by global settings + _user_is_rate_limited_by_global() + + elif is_api_key_email_address(user.email): + # API keys are only rate limited by global settings + _user_is_rate_limited_by_global() + + else: + run_functions_tuples_in_parallel( + [ + (_user_is_rate_limited, (user.id,)), + (_user_is_rate_limited_by_group, (user.id,)), + (_user_is_rate_limited_by_global, ()), + ] + ) + + +""" +User rate limits +""" + + +def _user_is_rate_limited(user_id: UUID) -> None: + with get_session_context_manager() as db_session: + user_rate_limits = fetch_all_user_token_rate_limits( + db_session=db_session, enabled_only=True, ordered=False + ) + + if user_rate_limits: + user_cutoff_time = _get_cutoff_time(user_rate_limits) + user_usage = _fetch_user_usage(user_id, user_cutoff_time, db_session) + + if _is_rate_limited(user_rate_limits, user_usage): + raise HTTPException( + status_code=429, + detail="Token budget exceeded for user. Try again later.", + ) + + +def _fetch_user_usage( + user_id: UUID, cutoff_time: datetime, db_session: Session +) -> Sequence[tuple[datetime, int]]: + """ + Fetch user usage within the cutoff time, grouped by minute + """ + result = db_session.execute( + select( + func.date_trunc("minute", ChatMessage.time_sent), + func.sum(ChatMessage.token_count), + ) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .where(ChatSession.user_id == user_id, ChatMessage.time_sent >= cutoff_time) + .group_by(func.date_trunc("minute", ChatMessage.time_sent)) + ).all() + + return [(row[0], row[1]) for row in result] + + +""" +User Group rate limits +""" + + +def _user_is_rate_limited_by_group(user_id: UUID) -> None: + with get_session_context_manager() as db_session: + group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session) + + if group_rate_limits: + # Group cutoff time is the same for all groups. + # This could be optimized to only fetch the maximum cutoff time for + # a specific group, but seems unnecessary for now. + group_cutoff_time = _get_cutoff_time( + [e for sublist in group_rate_limits.values() for e in sublist] + ) + + user_group_ids = list(group_rate_limits.keys()) + group_usage = _fetch_user_group_usage( + user_group_ids, group_cutoff_time, db_session + ) + + has_at_least_one_untriggered_limit = False + for user_group_id, rate_limits in group_rate_limits.items(): + usage = group_usage.get(user_group_id, []) + + if not _is_rate_limited(rate_limits, usage): + has_at_least_one_untriggered_limit = True + break + + if not has_at_least_one_untriggered_limit: + raise HTTPException( + status_code=429, + detail="Token budget exceeded for user's groups. Try again later.", + ) + + +def _fetch_all_user_group_rate_limits( + user_id: UUID, db_session: Session +) -> Dict[int, List[TokenRateLimit]]: + group_limits = ( + select(TokenRateLimit, User__UserGroup.user_group_id) + .join( + TokenRateLimit__UserGroup, + TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, + ) + .join( + UserGroup, + UserGroup.id == TokenRateLimit__UserGroup.user_group_id, + ) + .join( + User__UserGroup, + User__UserGroup.user_group_id == UserGroup.id, + ) + .where( + User__UserGroup.user_id == user_id, + TokenRateLimit.enabled.is_(True), + ) + ) + + raw_rate_limits = db_session.execute(group_limits).all() + + group_rate_limits = defaultdict(list) + for rate_limit, user_group_id in raw_rate_limits: + group_rate_limits[user_group_id].append(rate_limit) + + return group_rate_limits + + +def _fetch_user_group_usage( + user_group_ids: list[int], cutoff_time: datetime, db_session: Session +) -> dict[int, list[Tuple[datetime, int]]]: + """ + Fetch user group usage within the cutoff time, grouped by minute + """ + user_group_usage = db_session.execute( + select( + func.sum(ChatMessage.token_count), + func.date_trunc("minute", ChatMessage.time_sent), + UserGroup.id, + ) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .join(User__UserGroup, User__UserGroup.user_id == ChatSession.user_id) + .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id) + .filter(UserGroup.id.in_(user_group_ids), ChatMessage.time_sent >= cutoff_time) + .group_by(func.date_trunc("minute", ChatMessage.time_sent), UserGroup.id) + ).all() + + return { + user_group_id: [(usage, time_sent) for time_sent, usage, _ in group_usage] + for user_group_id, group_usage in groupby( + user_group_usage, key=lambda row: row[2] + ) + } diff --git a/backend/ee/danswer/server/token_rate_limits/api.py b/backend/ee/danswer/server/token_rate_limits/api.py new file mode 100644 index 0000000000..aac3ebb16c --- /dev/null +++ b/backend/ee/danswer/server/token_rate_limits/api.py @@ -0,0 +1,105 @@ +from collections import defaultdict + +from fastapi import APIRouter +from fastapi import Depends +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.server.query_and_chat.token_limit import any_rate_limit_exists +from danswer.server.token_rate_limits.models import TokenRateLimitArgs +from danswer.server.token_rate_limits.models import TokenRateLimitDisplay +from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits +from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group +from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits +from ee.danswer.db.token_limit import insert_user_group_token_rate_limit +from ee.danswer.db.token_limit import insert_user_token_rate_limit + +router = APIRouter(prefix="/admin/token-rate-limits") + + +""" +Group Token Limit Settings +""" + + +@router.get("/user-groups") +def get_all_group_token_limit_settings( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> dict[str, list[TokenRateLimitDisplay]]: + user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group( + db_session + ) + + token_rate_limits_by_group = defaultdict(list) + for token_rate_limit, group_name in user_groups_to_token_rate_limits: + token_rate_limits_by_group[group_name].append( + TokenRateLimitDisplay.from_db(token_rate_limit) + ) + + return dict(token_rate_limits_by_group) + + +@router.get("/user-group/{group_id}") +def get_group_token_limit_settings( + group_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[TokenRateLimitDisplay]: + return [ + TokenRateLimitDisplay.from_db(token_rate_limit) + for token_rate_limit in fetch_all_user_group_token_rate_limits( + db_session, group_id + ) + ] + + +@router.post("/user-group/{group_id}") +def create_group_token_limit_settings( + group_id: int, + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + rate_limit_display = TokenRateLimitDisplay.from_db( + insert_user_group_token_rate_limit( + db_session=db_session, + token_rate_limit_settings=token_limit_settings, + group_id=group_id, + ) + ) + # clear cache in case this was the first rate limit created + any_rate_limit_exists.cache_clear() + return rate_limit_display + + +""" +User Token Limit Settings +""" + + +@router.get("/users") +def get_user_token_limit_settings( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[TokenRateLimitDisplay]: + return [ + TokenRateLimitDisplay.from_db(token_rate_limit) + for token_rate_limit in fetch_all_user_token_rate_limits(db_session) + ] + + +@router.post("/users") +def create_user_token_limit_settings( + token_limit_settings: TokenRateLimitArgs, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> TokenRateLimitDisplay: + rate_limit_display = TokenRateLimitDisplay.from_db( + insert_user_token_rate_limit(db_session, token_limit_settings) + ) + # clear cache in case this was the first rate limit created + any_rate_limit_exists.cache_clear() + return rate_limit_display diff --git a/web/next.config.js b/web/next.config.js index 166de45093..9783bcd2a8 100644 --- a/web/next.config.js +++ b/web/next.config.js @@ -48,6 +48,11 @@ const nextConfig = { source: "/admin/performance/custom-analytics", destination: "/ee/admin/performance/custom-analytics", }, + // token rate limits + { + source: "/admin/token-rate-limits", + destination: "/ee/admin/token-rate-limits", + }, ] : []; diff --git a/web/src/app/admin/models/llm/page.tsx b/web/src/app/admin/models/llm/page.tsx index 330718ee60..bd449d9391 100644 --- a/web/src/app/admin/models/llm/page.tsx +++ b/web/src/app/admin/models/llm/page.tsx @@ -1,176 +1,9 @@ "use client"; -import { Form, Formik } from "formik"; -import { useEffect, useState } from "react"; import { AdminPageTitle } from "@/components/admin/Title"; -import { - BooleanFormField, - SectionHeader, - TextFormField, -} from "@/components/admin/connectors/Field"; -import { Popup } from "@/components/admin/connectors/Popup"; -import { Button, Divider, Text } from "@tremor/react"; import { FiCpu } from "react-icons/fi"; import { LLMConfiguration } from "./LLMConfiguration"; -const LLMOptions = () => { - const [popup, setPopup] = useState<{ - message: string; - type: "success" | "error"; - } | null>(null); - - const [tokenBudgetGloballyEnabled, setTokenBudgetGloballyEnabled] = - useState(false); - const [initialValues, setInitialValues] = useState({ - enable_token_budget: false, - token_budget: "", - token_budget_time_period: "", - }); - - const fetchConfig = async () => { - const response = await fetch("/api/manage/admin/token-budget-settings"); - if (response.ok) { - const config = await response.json(); - // Assuming the config object directly matches the structure needed for initialValues - setInitialValues({ - enable_token_budget: config.enable_token_budget || false, - token_budget: config.token_budget || "", - token_budget_time_period: config.token_budget_time_period || "", - }); - setTokenBudgetGloballyEnabled(true); - } else { - // Handle error or provide fallback values - setPopup({ - message: "Failed to load current LLM options.", - type: "error", - }); - } - }; - - // Fetch current config when the component mounts - useEffect(() => { - fetchConfig(); - }, []); - - if (!tokenBudgetGloballyEnabled) { - return null; - } - - return ( - <> - {popup && } - { - const response = await fetch( - "/api/manage/admin/token-budget-settings", - { - method: "PUT", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(values), - } - ); - if (response.ok) { - setPopup({ - message: "Updated LLM Options", - type: "success", - }); - await fetchConfig(); - } else { - const body = await response.json(); - if (body.detail) { - setPopup({ message: body.detail, type: "error" }); - } else { - setPopup({ - message: "Unable to update LLM options.", - type: "error", - }); - } - setTimeout(() => { - setPopup(null); - }, 4000); - } - }} - > - {({ isSubmitting, values, setFieldValue }) => { - return ( -
- - <> - Token Budget - - Set a maximum token use per time period. If the token budget - is exceeded, Danswer will not be able to respond to queries - until the next time period. - -
- { - setFieldValue("enable_token_budget", e.target.checked); - }} - /> - {values.enable_token_budget && ( - <> - - How many tokens (in thousands) can be used per time - period? If unspecified, no limit will be set. - - } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if (value === "" || /^[0-9]+$/.test(value)) { - setFieldValue("token_budget", value); - } - }} - /> - - Specify the length of the time period, in hours, over - which the token budget will be applied. - - } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if (value === "" || /^[0-9]+$/.test(value)) { - setFieldValue("token_budget_time_period", value); - } - }} - /> - - )} - -
- -
- - ); - }} -
- - ); -}; - const Page = () => { return (
@@ -181,7 +14,6 @@ const Page = () => { -
); }; diff --git a/web/src/app/admin/token-rate-limits/CreateRateLimitModal.tsx b/web/src/app/admin/token-rate-limits/CreateRateLimitModal.tsx new file mode 100644 index 0000000000..2d5f32ec33 --- /dev/null +++ b/web/src/app/admin/token-rate-limits/CreateRateLimitModal.tsx @@ -0,0 +1,175 @@ +"use client"; + +import * as Yup from "yup"; +import { Button } from "@tremor/react"; +import { useEffect, useState } from "react"; +import { Modal } from "@/components/Modal"; +import { Form, Formik } from "formik"; +import { + SelectorFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; +import { UserGroup } from "@/lib/types"; +import { Scope } from "./types"; +import { PopupSpec } from "@/components/admin/connectors/Popup"; + +interface CreateRateLimitModalProps { + isOpen: boolean; + setIsOpen: (isOpen: boolean) => void; + onSubmit: ( + target_scope: Scope, + period_hours: number, + token_budget: number, + group_id: number + ) => void; + setPopup: (popupSpec: PopupSpec | null) => void; + forSpecificScope?: Scope; + forSpecificUserGroup?: number; +} + +export const CreateRateLimitModal = ({ + isOpen, + setIsOpen, + onSubmit, + setPopup, + forSpecificScope, + forSpecificUserGroup, +}: CreateRateLimitModalProps) => { + const [modalUserGroups, setModalUserGroups] = useState([]); + const [shouldFetchUserGroups, setShouldFetchUserGroups] = useState( + forSpecificScope === Scope.USER_GROUP + ); + + useEffect(() => { + const fetchData = async () => { + try { + const response = await fetch("/api/manage/admin/user-group"); + const data = await response.json(); + const options = data.map((userGroup: UserGroup) => ({ + name: userGroup.name, + value: userGroup.id, + })); + setModalUserGroups(options); + setShouldFetchUserGroups(false); + } catch (error) { + setPopup({ + type: "error", + message: `Failed to fetch user groups: ${error}`, + }); + } + }; + + if (shouldFetchUserGroups) { + fetchData(); + } + }, [shouldFetchUserGroups, setPopup]); + + if (!isOpen) { + return null; + } + + return ( + setIsOpen(false)} + width="w-2/6" + > + { + return ( + context.parent.target_scope !== "user_group" || + (context.parent.target_scope === "user_group" && + value !== undefined) + ); + } + ), + })} + onSubmit={async (values, formikHelpers) => { + formikHelpers.setSubmitting(true); + onSubmit( + values.target_scope, + Number(values.period_hours), + Number(values.token_budget), + Number(values.user_group_id) + ); + return formikHelpers.setSubmitting(false); + }} + > + {({ isSubmitting, values, setFieldValue }) => ( +
+ {!forSpecificScope && ( + { + setFieldValue("target_scope", selected) + if (selected === Scope.USER_GROUP) { + setShouldFetchUserGroups(true); + } + }} + /> + )} + {forSpecificUserGroup === undefined && + values.target_scope === Scope.USER_GROUP && ( + + )} + + +
+ +
+ + )} +
+
+ ); +}; diff --git a/web/src/app/admin/token-rate-limits/TokenRateLimitTables.tsx b/web/src/app/admin/token-rate-limits/TokenRateLimitTables.tsx new file mode 100644 index 0000000000..71e550a1c4 --- /dev/null +++ b/web/src/app/admin/token-rate-limits/TokenRateLimitTables.tsx @@ -0,0 +1,169 @@ +"use client"; + +import { + Table, + TableHead, + TableRow, + TableHeaderCell, + TableBody, + TableCell, + Title, + Text, +} from "@tremor/react"; +import { DeleteButton } from "@/components/DeleteButton"; +import { deleteTokenRateLimit, updateTokenRateLimit } from "./lib"; +import { ThreeDotsLoader } from "@/components/Loading"; +import { TokenRateLimitDisplay } from "./types"; +import { errorHandlingFetcher } from "@/lib/fetcher"; +import useSWR, { mutate } from "swr"; +import { CustomCheckbox } from "@/components/CustomCheckbox"; + +type TokenRateLimitTableArgs = { + tokenRateLimits: TokenRateLimitDisplay[]; + title?: string; + description?: string; + fetchUrl: string; + hideHeading?: boolean; +}; + +export const TokenRateLimitTable = ({ + tokenRateLimits, + title, + description, + fetchUrl, + hideHeading, +}: TokenRateLimitTableArgs) => { + const shouldRenderGroupName = () => + tokenRateLimits.length > 0 && tokenRateLimits[0].group_name !== undefined; + + const handleEnabledChange = (id: number) => { + const tokenRateLimit = tokenRateLimits.find( + (tokenRateLimit) => tokenRateLimit.token_id === id + ); + + if (!tokenRateLimit) { + return; + } + + updateTokenRateLimit(id, { + token_budget: tokenRateLimit.token_budget, + period_hours: tokenRateLimit.period_hours, + enabled: !tokenRateLimit.enabled, + }).then(() => { + mutate(fetchUrl); + }); + }; + + const handleDelete = (id: number) => + deleteTokenRateLimit(id).then(() => { + mutate(fetchUrl); + }); + + if (tokenRateLimits.length === 0) { + return ( +
+ {!hideHeading && title && {title}} + {!hideHeading && description && ( + {description} + )} + + No token rate limits set! + +
+ ); + } + + return ( +
+ {!hideHeading && title && {title}} + {!hideHeading && description && ( + {description} + )} + + + + Enabled + {shouldRenderGroupName() && ( + Group Name + )} + Time Window (Hours) + Token Budget (Thousands) + Delete + + + + {tokenRateLimits.map((tokenRateLimit) => { + return ( + + +
handleEnabledChange(tokenRateLimit.token_id)} + className="px-1 py-0.5 hover:bg-hover-light rounded flex cursor-pointer select-none w-24 flex" + > +
+ +

+ {tokenRateLimit.enabled ? "Enabled" : "Disabled"} +

+
+
+
+ {shouldRenderGroupName() && ( + + {tokenRateLimit.group_name} + + )} + {tokenRateLimit.period_hours} + {tokenRateLimit.token_budget} + + handleDelete(tokenRateLimit.token_id)} + /> + +
+ ); + })} +
+
+
+ ); +}; + +export const GenericTokenRateLimitTable = ({ + fetchUrl, + title, + description, + hideHeading, + responseMapper, +}: { + fetchUrl: string; + title?: string; + description?: string; + hideHeading?: boolean; + responseMapper?: (data: any) => TokenRateLimitDisplay[]; +}) => { + const { data, isLoading, error } = useSWR(fetchUrl, errorHandlingFetcher); + + if (isLoading) { + return ; + } + + if (!isLoading && error) { + return Failed to load token rate limits; + } + + let processedData = data; + if (responseMapper) { + processedData = responseMapper(data); + } + + return ( + + ); +}; diff --git a/web/src/app/admin/token-rate-limits/lib.ts b/web/src/app/admin/token-rate-limits/lib.ts new file mode 100644 index 0000000000..f63db82c82 --- /dev/null +++ b/web/src/app/admin/token-rate-limits/lib.ts @@ -0,0 +1,64 @@ +import { TokenRateLimitArgs } from "./types"; + +const API_PREFIX = "/api/admin/token-rate-limits"; + +// Global Token Limits +export const insertGlobalTokenRateLimit = async ( + tokenRateLimit: TokenRateLimitArgs +) => { + return await fetch(`${API_PREFIX}/global`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(tokenRateLimit), + }); +}; + +// User Token Limits +export const insertUserTokenRateLimit = async ( + tokenRateLimit: TokenRateLimitArgs +) => { + return await fetch(`${API_PREFIX}/users`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(tokenRateLimit), + }); +}; + +// User Group Token Limits (EE Only) +export const insertGroupTokenRateLimit = async ( + tokenRateLimit: TokenRateLimitArgs, + group_id: number +) => { + return await fetch(`${API_PREFIX}/user-group/${group_id}`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(tokenRateLimit), + }); +}; + +// Common Endpoints + +export const deleteTokenRateLimit = async (token_rate_limit_id: number) => { + return await fetch(`${API_PREFIX}/rate-limit/${token_rate_limit_id}`, { + method: "DELETE", + }); +}; + +export const updateTokenRateLimit = async ( + token_rate_limit_id: number, + tokenRateLimit: TokenRateLimitArgs +) => { + return await fetch(`${API_PREFIX}/rate-limit/${token_rate_limit_id}`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(tokenRateLimit), + }); +}; diff --git a/web/src/app/admin/token-rate-limits/page.tsx b/web/src/app/admin/token-rate-limits/page.tsx new file mode 100644 index 0000000000..5982b37710 --- /dev/null +++ b/web/src/app/admin/token-rate-limits/page.tsx @@ -0,0 +1,223 @@ +"use client"; + +import { AdminPageTitle } from "@/components/admin/Title"; +import { + Button, + Tab, + TabGroup, + TabList, + TabPanel, + TabPanels, + Text, +} from "@tremor/react"; +import { useState } from "react"; +import { FiGlobe, FiShield, FiUser, FiUsers } from "react-icons/fi"; +import { + insertGlobalTokenRateLimit, + insertGroupTokenRateLimit, + insertUserTokenRateLimit, +} from "./lib"; +import { Scope, TokenRateLimit } from "./types"; +import { GenericTokenRateLimitTable } from "./TokenRateLimitTables"; +import { mutate } from "swr"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { CreateRateLimitModal } from "./CreateRateLimitModal"; +import { EE_ENABLED } from "@/lib/constants"; + +const BASE_URL = "/api/admin/token-rate-limits"; +const GLOBAL_TOKEN_FETCH_URL = `${BASE_URL}/global`; +const USER_TOKEN_FETCH_URL = `${BASE_URL}/users`; +const USER_GROUP_FETCH_URL = `${BASE_URL}/user-groups`; + +const GLOBAL_DESCRIPTION = + "Global rate limits apply to all users, user groups, and API keys. When the global \ + rate limit is reached, no more tokens can be spent."; +const USER_DESCRIPTION = + "User rate limits apply to individual users. When a user reaches a limit, they will \ + be temporarily blocked from spending tokens."; +const USER_GROUP_DESCRIPTION = + "User group rate limits apply to all users in a group. When a group reaches a limit, \ + all users in the group will be temporarily blocked from spending tokens, regardless \ + of their individual limits. If a user is in multiple groups, the most lenient limit \ + will apply."; + +const handleCreateTokenRateLimit = async ( + target_scope: Scope, + period_hours: number, + token_budget: number, + group_id: number = -1 +) => { + const tokenRateLimitArgs = { + enabled: true, + token_budget: token_budget, + period_hours: period_hours, + }; + + if (target_scope === Scope.GLOBAL) { + return await insertGlobalTokenRateLimit(tokenRateLimitArgs); + } else if (target_scope === Scope.USER) { + return await insertUserTokenRateLimit(tokenRateLimitArgs); + } else if (target_scope === Scope.USER_GROUP) { + return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id); + } else { + throw new Error(`Invalid target_scope: ${target_scope}`); + } +}; + +function Main() { + const [tabIndex, setTabIndex] = useState(0); + const [modalIsOpen, setModalIsOpen] = useState(false); + const { popup, setPopup } = usePopup(); + + const updateTable = (target_scope: Scope) => { + if (target_scope === Scope.GLOBAL) { + mutate(GLOBAL_TOKEN_FETCH_URL); + setTabIndex(0); + } else if (target_scope === Scope.USER) { + mutate(USER_TOKEN_FETCH_URL); + setTabIndex(1); + } else if (target_scope === Scope.USER_GROUP) { + mutate(USER_GROUP_FETCH_URL); + setTabIndex(2); + } + }; + + const handleSubmit = ( + target_scope: Scope, + period_hours: number, + token_budget: number, + group_id: number = -1 + ) => { + handleCreateTokenRateLimit( + target_scope, + period_hours, + token_budget, + group_id + ) + .then(() => { + setModalIsOpen(false); + setPopup({ type: "success", message: "Token rate limit created!" }); + updateTable(target_scope); + }) + .catch((error) => { + setPopup({ type: "error", message: error.message }); + }); + }; + + return ( +
+ {popup} + + + Token rate limits enable you control how many tokens can be spent in a + given time period. With token rate limits, you can: + + +
    +
  • + + Set a global rate limit to control your organization's overall + token spend. + +
  • + {EE_ENABLED && ( + <> +
  • + + Set rate limits for users to ensure that no single user can + spend too many tokens. + +
  • +
  • + + Set rate limits for user groups to control token spend for your + teams. + +
  • + + )} +
  • + Enable and disable rate limits on the fly. +
  • +
+ + + + {EE_ENABLED && ( + + + Global + User + User Groups + + + + + + + + + + ) => + Object.entries(data).flatMap(([group_name, elements]) => + elements.map((element) => ({ + ...element, + group_name, + })) + ) + } + /> + + + + )} + + {!EE_ENABLED && ( +
+ +
+ )} + + setModalIsOpen(false)} + setPopup={setPopup} + onSubmit={handleSubmit} + forSpecificScope={EE_ENABLED ? undefined : Scope.GLOBAL} + /> +
+ ); +} + +export default function Page() { + return ( +
+ } /> + +
+
+ ); +} diff --git a/web/src/app/admin/token-rate-limits/types.ts b/web/src/app/admin/token-rate-limits/types.ts new file mode 100644 index 0000000000..8ea457915b --- /dev/null +++ b/web/src/app/admin/token-rate-limits/types.ts @@ -0,0 +1,22 @@ +export enum Scope { + USER = "user", + USER_GROUP = "user_group", + GLOBAL = "global", +} + +export interface TokenRateLimitArgs { + enabled: boolean; + token_budget: number; + period_hours: number; +} + +export interface TokenRateLimit { + token_id: number; + enabled: boolean; + token_budget: number; + period_hours: number; +} + +export interface TokenRateLimitDisplay extends TokenRateLimit { + group_name?: string; +} diff --git a/web/src/app/ee/admin/groups/[groupId]/AddTokenRateLimitForm.tsx b/web/src/app/ee/admin/groups/[groupId]/AddTokenRateLimitForm.tsx new file mode 100644 index 0000000000..37497bdbbe --- /dev/null +++ b/web/src/app/ee/admin/groups/[groupId]/AddTokenRateLimitForm.tsx @@ -0,0 +1,60 @@ +import { PopupSpec } from "@/components/admin/connectors/Popup"; +import { CreateRateLimitModal } from "../../../../admin/token-rate-limits/CreateRateLimitModal"; +import { Scope } from "../../../../admin/token-rate-limits/types"; +import { insertGroupTokenRateLimit } from "../../../../admin/token-rate-limits/lib"; +import { mutate } from "swr"; + +interface AddMemberFormProps { + isOpen: boolean; + setIsOpen: (isOpen: boolean) => void; + setPopup: (popupSpec: PopupSpec | null) => void; + userGroupId: number; +} + +const handleCreateGroupTokenRateLimit = async ( + period_hours: number, + token_budget: number, + group_id: number = -1 +) => { + const tokenRateLimitArgs = { + enabled: true, + token_budget: token_budget, + period_hours: period_hours, + }; + return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id); +}; + +export const AddTokenRateLimitForm: React.FC = ({ + isOpen, + setIsOpen, + setPopup, + userGroupId, +}) => { + const handleSubmit = ( + _: Scope, + period_hours: number, + token_budget: number, + group_id: number = -1 + ) => { + handleCreateGroupTokenRateLimit(period_hours, token_budget, group_id) + .then(() => { + setIsOpen(false); + setPopup({ type: "success", message: "Token rate limit created!" }); + mutate(`/api/admin/token-rate-limits/user-group/${userGroupId}`); + }) + .catch((error) => { + setPopup({ type: "error", message: error.message }); + }); + }; + + return ( + + ); +}; diff --git a/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx b/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx index bb3b9f5252..c361c0cdee 100644 --- a/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx +++ b/web/src/app/ee/admin/groups/[groupId]/GroupDisplay.tsx @@ -22,6 +22,8 @@ import { import { DeleteButton } from "@/components/DeleteButton"; import { Bubble } from "@/components/Bubble"; import { BookmarkIcon, RobotIcon } from "@/components/icons/icons"; +import { AddTokenRateLimitForm } from "./AddTokenRateLimitForm"; +import { GenericTokenRateLimitTable } from "@/app/admin/token-rate-limits/TokenRateLimitTables"; interface GroupDisplayProps { users: User[]; @@ -39,6 +41,7 @@ export const GroupDisplay = ({ const { popup, setPopup } = usePopup(); const [addMemberFormVisible, setAddMemberFormVisible] = useState(false); const [addConnectorFormVisible, setAddConnectorFormVisible] = useState(false); + const [addRateLimitFormVisible, setAddRateLimitFormVisible] = useState(false); return (
@@ -301,6 +304,31 @@ export const GroupDisplay = ({ )}
+ + + +

Token Rate Limits

+ + + + + + ); }; diff --git a/web/src/components/admin/Layout.tsx b/web/src/components/admin/Layout.tsx index ebea06e175..1009fe91b9 100644 --- a/web/src/components/admin/Layout.tsx +++ b/web/src/components/admin/Layout.tsx @@ -28,6 +28,7 @@ import { FiImage, FiPackage, FiSettings, + FiShield, FiSlack, FiTool, } from "react-icons/fi"; @@ -215,6 +216,15 @@ export async function Layout({ children }: { children: React.ReactNode }) { }, ] : []), + { + name: ( +
+ +
Token Rate Limits
+
+ ), + link: "/admin/token-rate-limits", + }, ], }, ...(EE_ENABLED