mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-24 15:00:57 +02:00
Token Rate Limiting
WIP Cleanup 🧹 Remove existing rate limiting logic Cleanup 🧼 Undo nit Cleanup 🧽 Move db constants (avoids circular import) WIP WIP Cleanup Lint Resolve alembic conflict Fix mypy Add backfill to migration Update comment Make unauthenticated users still adhere to global limits Use Depends Remove enum from table Update migration error handling + deletion Address verbal feedback, cleanup urls, minor nits
This commit is contained in:
parent
7a408749cf
commit
d7a704c0d9
@ -41,10 +41,6 @@ DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
QUERY_EVENT_ID = "query_event_id"
|
||||
LLM_CHUNKS = "llm_chunks"
|
||||
TOKEN_BUDGET = "token_budget"
|
||||
TOKEN_BUDGET_TIME_PERIOD = "token_budget_time_period"
|
||||
ENABLE_TOKEN_BUDGET = "enable_token_budget"
|
||||
TOKEN_BUDGET_SETTINGS = "token_budget_settings"
|
||||
|
||||
# For chunking/processing chunks
|
||||
TITLE_SEPARATOR = "\n\r\n"
|
||||
|
@ -83,6 +83,9 @@ from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
|
||||
from danswer.tools.built_in_tools import load_builtin_tools
|
||||
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@ -281,6 +284,9 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, settings_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
@ -1,23 +1,16 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Body
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import ENABLE_TOKEN_BUDGET
|
||||
from danswer.configs.constants import TOKEN_BUDGET
|
||||
from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
|
||||
from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.engine import get_session
|
||||
@ -193,41 +186,3 @@ def create_deletion_attempt_for_connector_id(
|
||||
file_store = get_default_file_store(db_session)
|
||||
for file_name in connector.connector_specific_config["file_locations"]:
|
||||
file_store.delete_file(file_name)
|
||||
|
||||
|
||||
@router.get("/admin/token-budget-settings")
def get_token_budget_settings(_: User = Depends(current_admin_user)) -> dict:
    """Return the stored token budget settings (admin only)."""
    # The feature flag is checked first: when it is off, the stored settings
    # are never consulted.
    if not TOKEN_BUDGET_GLOBALLY_ENABLED:
        raise HTTPException(
            status_code=400, detail="Token budget is not enabled in the application."
        )

    try:
        raw_settings = cast(
            str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS)
        )
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Token budget settings not found.")

    # Settings are persisted as a JSON-encoded string.
    return json.loads(raw_settings)
|
||||
|
||||
|
||||
@router.put("/admin/token-budget-settings")
def update_token_budget_settings(
    _: User = Depends(current_admin_user),
    enable_token_budget: bool = Body(..., embed=True),
    token_budget: int = Body(..., ge=0, embed=True),  # Ensure non-negative
    token_budget_time_period: int = Body(..., ge=1, embed=True),  # Ensure positive
) -> dict[str, str]:
    """Persist the token budget settings in the dynamic config store (admin only)."""
    new_settings = {
        ENABLE_TOKEN_BUDGET: enable_token_budget,
        TOKEN_BUDGET: token_budget,
        TOKEN_BUDGET_TIME_PERIOD: token_budget_time_period,
    }

    # The store keeps the settings as a single JSON-encoded string.
    get_dynamic_config_store().store(TOKEN_BUDGET_SETTINGS, json.dumps(new_settings))
    return {"message": "Token budget settings updated successfully."}
|
||||
|
@ -64,6 +64,7 @@ from danswer.server.query_and_chat.models import PromptOverride
|
||||
from danswer.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from danswer.server.query_and_chat.models import SearchFeedbackRequest
|
||||
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -275,6 +276,7 @@ def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
user: User | None = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
"""This endpoint is both used for all the following purposes:
|
||||
- Sending a new message in the session
|
||||
|
@ -29,7 +29,7 @@ from danswer.server.query_and_chat.models import QueryValidationResponse
|
||||
from danswer.server.query_and_chat.models import SimpleQueryRequest
|
||||
from danswer.server.query_and_chat.models import SourceTag
|
||||
from danswer.server.query_and_chat.models import TagResponse
|
||||
from danswer.server.query_and_chat.token_budget import check_token_budget
|
||||
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -150,7 +150,7 @@ def stream_query_validation(
|
||||
def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User = Depends(current_user),
|
||||
_: bool = Depends(check_token_budget),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
query = query_request.messages[0].message
|
||||
logger.info(f"Received query for one shot answer with quotes: {query}")
|
||||
|
@ -1,79 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
|
||||
from danswer.configs.constants import ENABLE_TOKEN_BUDGET
|
||||
from danswer.configs.constants import TOKEN_BUDGET
|
||||
from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
|
||||
from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
|
||||
BUDGET_LIMIT_DEFAULT = -1 # Default to no limit
|
||||
TIME_PERIOD_HOURS_DEFAULT = 12
|
||||
|
||||
|
||||
def is_under_token_budget(db_session: Session) -> bool:
    """Report whether recent token usage is still below the configured budget.

    Returns True (not limited) when the settings cannot be loaded, when the
    budget is disabled, or when the configured budget is negative. Otherwise
    sums `ChatMessage.token_count` over the configured look-back period and
    compares it against the budget.
    """
    try:
        raw_settings = cast(
            str, get_dynamic_config_store().load(TOKEN_BUDGET_SETTINGS)
        )
    except Exception:
        # Any failure to load the settings is treated as "no budget configured".
        return True

    settings = json.loads(raw_settings)

    if not settings.get(ENABLE_TOKEN_BUDGET, False):
        return True

    budget_limit = settings.get(TOKEN_BUDGET, -1)
    if budget_limit < 0:
        return True

    period_hours = settings.get(TOKEN_BUDGET_TIME_PERIOD, TIME_PERIOD_HOURS_DEFAULT)
    period_start_time = datetime.now() - timedelta(hours=period_hours)

    # Fetch the sum of all tokens used within the period
    usage_query = db_session.query(func.sum(ChatMessage.token_count)).filter(
        ChatMessage.time_sent >= period_start_time
    )
    token_sum = usage_query.scalar() or 0

    print(
        "token_sum:",
        token_sum,
        "budget_limit:",
        budget_limit,
        "period_hours:",
        period_hours,
        "period_start_time:",
        period_start_time,
    )

    # Budget limit is expressed in thousands of tokens
    budget_in_tokens = budget_limit * 1000
    return token_sum < budget_in_tokens
|
||||
|
||||
|
||||
def check_token_budget() -> None:
    """Raise HTTP 429 when the global token budget has been used up.

    Does nothing when the token budget feature is globally disabled.
    """
    if not TOKEN_BUDGET_GLOBALLY_ENABLED:
        return None

    with get_session_context_manager() as db_session:
        budget_exceeded = not is_under_token_budget(db_session)
        if budget_exceeded:
            raise HTTPException(
                status_code=429, detail="Sorry, token budget exceeded. Try again later."
            )
|
135
backend/danswer/server/query_and_chat/token_limit.py
Normal file
135
backend/danswer/server/query_and_chat/token_limit.py
Normal file
@ -0,0 +1,135 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from functools import lru_cache
|
||||
|
||||
from dateutil import tz
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import TokenRateLimit
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
TOKEN_BUDGET_UNIT = 1_000
|
||||
|
||||
|
||||
def check_token_rate_limits(
    user: User | None = Depends(current_user),
) -> None:
    """FastAPI dependency that enforces the configured token rate limits.

    The actual check is resolved through `fetch_versioned_implementation`, so
    the strategy that runs is whichever `_check_token_rate_limits` that lookup
    returns for this module path.
    """
    # NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
    if any_rate_limit_exists():
        rate_limit_strategy = fetch_versioned_implementation(
            "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
        )
        return rate_limit_strategy(user)

    # short circuit: no rate limits are set up
    return None
|
||||
|
||||
|
||||
def _check_token_rate_limits(_: User | None) -> None:
    """Default strategy: apply only the global rate limits, ignoring the user."""
    _user_is_rate_limited_by_global()
|
||||
|
||||
|
||||
"""
|
||||
Global rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited_by_global() -> None:
    """Raise HTTP 429 if any enabled global rate limit has been reached."""
    with get_session_context_manager() as db_session:
        enabled_limits = fetch_all_global_token_rate_limits(
            db_session=db_session, enabled_only=True, ordered=False
        )
        if not enabled_limits:
            return

        # One usage query covering the longest window serves every limit.
        cutoff = _get_cutoff_time(enabled_limits)
        usage = _fetch_global_usage(cutoff, db_session)

        if _is_rate_limited(enabled_limits, usage):
            raise HTTPException(
                status_code=429,
                detail="Token budget exceeded for organization. Try again later.",
            )
|
||||
|
||||
|
||||
def _fetch_global_usage(
    cutoff_time: datetime, db_session: Session
) -> Sequence[tuple[datetime, int]]:
    """
    Fetch global token usage within the cutoff time, grouped by minute

    Returns one `(minute, token_sum)` pair per minute bucket.
    """
    minute_bucket = func.date_trunc("minute", ChatMessage.time_sent)
    usage_stmt = (
        select(minute_bucket, func.sum(ChatMessage.token_count))
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .filter(
            ChatMessage.time_sent >= cutoff_time,
        )
        .group_by(minute_bucket)
    )

    rows = db_session.execute(usage_stmt).all()
    return [(bucket, token_sum) for bucket, token_sum in rows]
|
||||
|
||||
|
||||
"""
|
||||
Common functions
|
||||
"""
|
||||
|
||||
|
||||
def _get_cutoff_time(rate_limits: Sequence[TokenRateLimit]) -> datetime:
    """Return the UTC time that starts the longest window among `rate_limits`.

    `rate_limits` must be non-empty: `max()` raises ValueError on an empty
    sequence.
    """
    longest_window = timedelta(
        hours=max(limit.period_hours for limit in rate_limits)
    )
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc - longest_window
|
||||
|
||||
|
||||
def _is_rate_limited(
    rate_limits: Sequence[TokenRateLimit], usage: Sequence[tuple[datetime, int]]
) -> bool:
    """
    If at least one rate limit is exceeded, return True

    Args:
        rate_limits: limits to evaluate; each contributes its own look-back
            window (`period_hours`) and budget (`token_budget`, multiplied by
            `TOKEN_BUDGET_UNIT`).
        usage: `(timestamp, token_count)` pairs. Timestamps are compared
            against timezone-aware UTC datetimes, so they must be
            timezone-aware as well.

    Returns:
        True when, for any limit, the tokens used inside its window have
        reached or passed its budget; False otherwise (including when
        `rate_limits` is empty).
    """
    # Take one timestamp for the whole check so every usage row and every
    # limit is measured against the same instant, rather than calling now()
    # once per usage row.
    now = datetime.now(tz=timezone.utc)

    for rate_limit in rate_limits:
        window_start = now - timedelta(hours=rate_limit.period_hours)
        tokens_used = sum(
            token_count
            for bucket_time, token_count in usage
            if bucket_time >= window_start
        )

        if tokens_used >= rate_limit.token_budget * TOKEN_BUDGET_UNIT:
            return True

    return False
|
||||
|
||||
|
||||
@lru_cache()
def any_rate_limit_exists() -> bool:
    """Checks if any rate limit exists in the database. Is cached, so that if no rate limits
    are setup, we don't have any effect on average query latency.

    The cached answer stays in place until `any_rate_limit_exists.cache_clear()`
    is called."""
    logger.info("Checking for any rate limits...")

    enabled_limit_stmt = select(TokenRateLimit.id).where(
        TokenRateLimit.enabled == True  # noqa: E712
    )
    with get_session_context_manager() as db_session:
        first_enabled_id = db_session.scalar(enabled_limit_stmt)
        return first_enabled_id is not None
|
79
backend/danswer/server/token_rate_limits/api.py
Normal file
79
backend/danswer/server/token_rate_limits/api.py
Normal file
@ -0,0 +1,79 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from ee.danswer.db.token_limit import delete_token_rate_limit
|
||||
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from ee.danswer.db.token_limit import insert_global_token_rate_limit
|
||||
from ee.danswer.db.token_limit import update_token_rate_limit
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
|
||||
|
||||
"""
|
||||
Global Token Limit Settings
|
||||
"""
|
||||
|
||||
|
||||
@router.get("/global")
def get_global_token_limit_settings(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
    """List every global token rate limit (admin only)."""
    global_limits = fetch_all_global_token_rate_limits(db_session)
    return [TokenRateLimitDisplay.from_db(limit) for limit in global_limits]
|
||||
|
||||
|
||||
@router.post("/global")
def create_global_token_limit_settings(
    token_limit_settings: TokenRateLimitArgs,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
    """Create a new global token rate limit (admin only)."""
    created_limit = insert_global_token_rate_limit(db_session, token_limit_settings)
    rate_limit_display = TokenRateLimitDisplay.from_db(created_limit)

    # clear cache in case this was the first rate limit created
    any_rate_limit_exists.cache_clear()
    return rate_limit_display
|
||||
|
||||
|
||||
"""
|
||||
General Token Limit Settings
|
||||
"""
|
||||
|
||||
|
||||
@router.put("/rate-limit/{token_rate_limit_id}")
def update_token_limit_settings(
    token_rate_limit_id: int,
    token_limit_settings: TokenRateLimitArgs,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
    """Update an existing token rate limit and return its new state (admin only)."""
    rate_limit_display = TokenRateLimitDisplay.from_db(
        update_token_rate_limit(
            db_session=db_session,
            token_rate_limit_id=token_rate_limit_id,
            token_rate_limit_settings=token_limit_settings,
        )
    )
    # An update can flip `enabled`, which changes the answer cached by
    # `any_rate_limit_exists`. Without clearing it, a limit that was created
    # disabled and enabled later would never be enforced by this process.
    any_rate_limit_exists.cache_clear()
    return rate_limit_display
|
||||
|
||||
|
||||
@router.delete("/rate-limit/{token_rate_limit_id}")
def delete_token_limit_settings(
    token_rate_limit_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete a token rate limit (admin only)."""
    delete_token_rate_limit(
        db_session=db_session,
        token_rate_limit_id=token_rate_limit_id,
    )
    # Deleting the last enabled limit changes the answer cached by
    # `any_rate_limit_exists`; clear it so the no-limits short circuit can
    # take effect again.
    any_rate_limit_exists.cache_clear()
|
25
backend/danswer/server/token_rate_limits/models.py
Normal file
25
backend/danswer/server/token_rate_limits/models.py
Normal file
@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.db.models import TokenRateLimit
|
||||
|
||||
|
||||
class TokenRateLimitArgs(BaseModel):
    # Payload accepted by the create/update token rate limit endpoints.

    # Whether the limit is enabled.
    enabled: bool
    # Token budget for the window. The enforcement code in
    # query_and_chat/token_limit.py multiplies it by TOKEN_BUDGET_UNIT (1_000).
    token_budget: int
    # Length of the look-back window, in hours.
    period_hours: int
|
||||
|
||||
|
||||
class TokenRateLimitDisplay(BaseModel):
    # Response shape returned by the token rate limit endpoints.

    # Database id of the `TokenRateLimit` row.
    token_id: int
    # Whether the limit is enabled.
    enabled: bool
    # Token budget for the window, as stored on the row.
    token_budget: int
    # Length of the look-back window, in hours.
    period_hours: int

    @classmethod
    def from_db(cls, token_rate_limit: TokenRateLimit) -> "TokenRateLimitDisplay":
        """Build the display model from a `TokenRateLimit` database row."""
        display_fields = {
            "token_id": token_rate_limit.id,
            "enabled": token_rate_limit.enabled,
            "token_budget": token_rate_limit.token_budget,
            "period_hours": token_rate_limit.period_hours,
        }
        return cls(**display_fields)
|
176
backend/ee/danswer/db/token_limit.py
Normal file
176
backend/ee/danswer/db/token_limit.py
Normal file
@ -0,0 +1,176 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import Row
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
from danswer.db.models import TokenRateLimit
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
|
||||
|
||||
def fetch_all_user_token_rate_limits(
    db_session: Session,
    enabled_only: bool = False,
    ordered: bool = True,
) -> Sequence[TokenRateLimit]:
    """Fetch the rate limits whose scope is `TokenRateLimitScope.USER`.

    Args:
        db_session: session used to run the query.
        enabled_only: when True, return only limits with `enabled` set.
        ordered: when True, sort by `created_at`, newest first.
    """
    stmt = select(TokenRateLimit).where(
        TokenRateLimit.scope == TokenRateLimitScope.USER
    )
    if enabled_only:
        stmt = stmt.where(TokenRateLimit.enabled.is_(True))
    if ordered:
        stmt = stmt.order_by(TokenRateLimit.created_at.desc())

    user_limits = db_session.scalars(stmt).all()
    return user_limits
|
||||
|
||||
|
||||
def fetch_all_global_token_rate_limits(
    db_session: Session,
    enabled_only: bool = False,
    ordered: bool = True,
) -> Sequence[TokenRateLimit]:
    """Fetch the rate limits whose scope is `TokenRateLimitScope.GLOBAL`.

    Args:
        db_session: session used to run the query.
        enabled_only: when True, return only limits with `enabled` set.
        ordered: when True, sort by `created_at`, newest first.
    """
    stmt = select(TokenRateLimit).where(
        TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
    )
    if enabled_only:
        stmt = stmt.where(TokenRateLimit.enabled.is_(True))
    if ordered:
        stmt = stmt.order_by(TokenRateLimit.created_at.desc())

    return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_all_user_group_token_rate_limits(
    db_session: Session, group_id: int, enabled_only: bool = False, ordered: bool = True
) -> Sequence[TokenRateLimit]:
    """Fetch the `USER_GROUP`-scoped rate limits attached to one user group.

    Args:
        db_session: session used to run the query.
        group_id: id of the user group whose limits are wanted.
        enabled_only: when True, return only limits with `enabled` set.
        ordered: when True, sort by `created_at`, newest first.
    """
    # Limits are linked to groups through the TokenRateLimit__UserGroup table.
    stmt = select(TokenRateLimit).join(
        TokenRateLimit__UserGroup,
        TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
    )
    stmt = stmt.where(
        TokenRateLimit__UserGroup.user_group_id == group_id,
        TokenRateLimit.scope == TokenRateLimitScope.USER_GROUP,
    )

    if enabled_only:
        stmt = stmt.where(TokenRateLimit.enabled.is_(True))
    if ordered:
        stmt = stmt.order_by(TokenRateLimit.created_at.desc())

    return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_all_user_group_token_rate_limits_by_group(
    db_session: Session,
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
    """Fetch every group-linked rate limit paired with its user group's name."""
    limits_with_group_name = select(TokenRateLimit, UserGroup.name).join(
        TokenRateLimit__UserGroup,
        TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
    )
    limits_with_group_name = limits_with_group_name.join(
        UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id
    )

    rows = db_session.execute(limits_with_group_name).all()
    return rows
|
||||
|
||||
|
||||
def insert_user_token_rate_limit(
    db_session: Session,
    token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
    """Create and commit a `USER`-scoped rate limit from the given settings."""
    settings = token_rate_limit_settings
    new_limit = TokenRateLimit(
        enabled=settings.enabled,
        token_budget=settings.token_budget,
        period_hours=settings.period_hours,
        scope=TokenRateLimitScope.USER,
    )

    db_session.add(new_limit)
    db_session.commit()
    return new_limit
|
||||
|
||||
|
||||
def insert_global_token_rate_limit(
    db_session: Session,
    token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
    """Create and commit a `GLOBAL`-scoped rate limit from the given settings."""
    settings = token_rate_limit_settings
    new_limit = TokenRateLimit(
        enabled=settings.enabled,
        token_budget=settings.token_budget,
        period_hours=settings.period_hours,
        scope=TokenRateLimitScope.GLOBAL,
    )

    db_session.add(new_limit)
    db_session.commit()
    return new_limit
|
||||
|
||||
|
||||
def insert_user_group_token_rate_limit(
    db_session: Session,
    token_rate_limit_settings: TokenRateLimitArgs,
    group_id: int,
) -> TokenRateLimit:
    """Create a `USER_GROUP`-scoped rate limit and link it to `group_id`.

    Both the limit and its link row are committed together.
    """
    settings = token_rate_limit_settings
    new_limit = TokenRateLimit(
        enabled=settings.enabled,
        token_budget=settings.token_budget,
        period_hours=settings.period_hours,
        scope=TokenRateLimitScope.USER_GROUP,
    )
    db_session.add(new_limit)
    # Flush before building the link row, which reads `new_limit.id`.
    db_session.flush()

    group_link = TokenRateLimit__UserGroup(
        rate_limit_id=new_limit.id, user_group_id=group_id
    )
    db_session.add(group_link)
    db_session.commit()

    return new_limit
|
||||
|
||||
|
||||
def update_token_rate_limit(
    db_session: Session,
    token_rate_limit_id: int,
    token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
    """Overwrite a rate limit's settings and commit the change.

    Raises:
        ValueError: if no `TokenRateLimit` with the given id exists.
    """
    existing_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
    if existing_limit is None:
        raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")

    # Copy each editable field from the request onto the row.
    for field_name in ("enabled", "token_budget", "period_hours"):
        setattr(
            existing_limit, field_name, getattr(token_rate_limit_settings, field_name)
        )
    db_session.commit()

    return existing_limit
|
||||
|
||||
|
||||
def delete_token_rate_limit(
    db_session: Session,
    token_rate_limit_id: int,
) -> None:
    """Delete a rate limit together with its user-group link rows.

    Raises:
        ValueError: if no `TokenRateLimit` with the given id exists.
    """
    doomed_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
    if doomed_limit is None:
        raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")

    # Remove the group link rows first, then the limit itself.
    group_links = db_session.query(TokenRateLimit__UserGroup).filter(
        TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
    )
    group_links.delete()

    db_session.delete(doomed_limit)
    db_session.commit()
|
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
@ -241,6 +242,21 @@ def update_user_group(
|
||||
return db_user_group
|
||||
|
||||
|
||||
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
    db_session: Session, user_group_id: int
) -> None:
    """Delete every rate-limit link row that points at `user_group_id`.

    NOTE: does not commit the transaction."""
    links_for_group = select(TokenRateLimit__UserGroup).where(
        TokenRateLimit__UserGroup.user_group_id == user_group_id
    )
    for link in db_session.scalars(links_for_group).all():
        db_session.delete(link)
|
||||
|
||||
|
||||
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
@ -255,6 +271,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_cleanup_token_rate_limit__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
@ -34,6 +34,9 @@ from ee.danswer.server.query_and_chat.query_backend import (
|
||||
)
|
||||
from ee.danswer.server.query_history.api import router as query_history_router
|
||||
from ee.danswer.server.saml import router as saml_router
|
||||
from ee.danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from ee.danswer.server.user_group.api import router as user_group_router
|
||||
|
||||
logger = setup_logger()
|
||||
@ -85,6 +88,10 @@ def get_ee_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, enterprise_settings_admin_router
|
||||
)
|
||||
# Token rate limit settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
|
184
backend/ee/danswer/server/query_and_chat/token_limit.py
Normal file
184
backend/ee/danswer/server/query_and_chat/token_limit.py
Normal file
@ -0,0 +1,184 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from itertools import groupby
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import TokenRateLimit
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.server.query_and_chat.token_limit import _get_cutoff_time
|
||||
from danswer.server.query_and_chat.token_limit import _is_rate_limited
|
||||
from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from ee.danswer.db.api_key import is_api_key_email_address
|
||||
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
|
||||
|
||||
def _check_token_rate_limits(user: User | None) -> None:
    """Apply the rate limits relevant to `user`.

    Unauthenticated users and API keys are checked against the global limits
    only; every other user is checked against the user, user-group and global
    limits, with the three checks handed to `run_functions_tuples_in_parallel`.
    """
    # Unauthenticated users and API keys are only rate limited by global settings
    if user is None or is_api_key_email_address(user.email):
        _user_is_rate_limited_by_global()
        return

    checks = [
        (_user_is_rate_limited, (user.id,)),
        (_user_is_rate_limited_by_group, (user.id,)),
        (_user_is_rate_limited_by_global, ()),
    ]
    run_functions_tuples_in_parallel(checks)
|
||||
|
||||
|
||||
"""
|
||||
User rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited(user_id: UUID) -> None:
    """Raise HTTP 429 if the user has reached any enabled user-scoped limit."""
    with get_session_context_manager() as db_session:
        enabled_limits = fetch_all_user_token_rate_limits(
            db_session=db_session, enabled_only=True, ordered=False
        )
        if not enabled_limits:
            return

        # One usage query covering the longest window serves every limit.
        cutoff = _get_cutoff_time(enabled_limits)
        usage = _fetch_user_usage(user_id, cutoff, db_session)

        if _is_rate_limited(enabled_limits, usage):
            raise HTTPException(
                status_code=429,
                detail="Token budget exceeded for user. Try again later.",
            )
|
||||
|
||||
|
||||
def _fetch_user_usage(
    user_id: UUID, cutoff_time: datetime, db_session: Session
) -> Sequence[tuple[datetime, int]]:
    """
    Fetch user usage within the cutoff time, grouped by minute

    Returns one `(minute, token_sum)` pair per minute bucket, counting only
    messages in chat sessions that belong to `user_id`.
    """
    minute_bucket = func.date_trunc("minute", ChatMessage.time_sent)
    usage_stmt = (
        select(minute_bucket, func.sum(ChatMessage.token_count))
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .where(ChatSession.user_id == user_id, ChatMessage.time_sent >= cutoff_time)
        .group_by(minute_bucket)
    )

    rows = db_session.execute(usage_stmt).all()
    return [(bucket, token_sum) for bucket, token_sum in rows]
|
||||
|
||||
|
||||
"""
|
||||
User Group rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
    """Raise HTTP 429 only when every one of the user's groups is rate limited.

    A single group that is still under all of its limits is enough to let the
    request through. Users with no group rate limits are never limited here.
    """
    with get_session_context_manager() as db_session:
        group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
        if not group_rate_limits:
            return

        # Group cutoff time is the same for all groups.
        # This could be optimized to only fetch the maximum cutoff time for
        # a specific group, but seems unnecessary for now.
        all_limits = [
            limit for limits in group_rate_limits.values() for limit in limits
        ]
        group_cutoff_time = _get_cutoff_time(all_limits)

        group_usage = _fetch_user_group_usage(
            list(group_rate_limits.keys()), group_cutoff_time, db_session
        )

        # all() stops at the first group that is not rate limited.
        every_group_limited = all(
            _is_rate_limited(limits, group_usage.get(group_id, []))
            for group_id, limits in group_rate_limits.items()
        )
        if every_group_limited:
            raise HTTPException(
                status_code=429,
                detail="Token budget exceeded for user's groups. Try again later.",
            )
|
||||
|
||||
|
||||
def _fetch_all_user_group_rate_limits(
    user_id: UUID, db_session: Session
) -> Dict[int, List[TokenRateLimit]]:
    """
    Collect the enabled token rate limits attached to the user groups that
    `user_id` belongs to.

    Returns a mapping of user group id -> rate limits of that group. Groups
    without an enabled rate limit do not appear in the mapping.
    """
    query = select(TokenRateLimit, User__UserGroup.user_group_id)
    # rate limit -> association row -> group -> membership rows of that group
    query = query.join(
        TokenRateLimit__UserGroup,
        TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
    )
    query = query.join(
        UserGroup,
        UserGroup.id == TokenRateLimit__UserGroup.user_group_id,
    )
    query = query.join(
        User__UserGroup,
        User__UserGroup.user_group_id == UserGroup.id,
    )
    query = query.where(
        User__UserGroup.user_id == user_id,
        TokenRateLimit.enabled.is_(True),
    )

    limits_by_group: Dict[int, List[TokenRateLimit]] = defaultdict(list)
    for row in db_session.execute(query).all():
        rate_limit, group_id = row
        limits_by_group[group_id].append(rate_limit)

    return limits_by_group
|
||||
|
||||
|
||||
def _fetch_user_group_usage(
    user_group_ids: list[int], cutoff_time: datetime, db_session: Session
) -> dict[int, list[Tuple[datetime, int]]]:
    """
    Fetch user group usage within the cutoff time, grouped by minute

    Token counts of chat messages sent at or after `cutoff_time` are summed
    per (user group, minute) for the members of the given groups.

    Returns a mapping of user group id -> list of
    (minute bucket, summed token count) pairs. Groups with no matching
    messages are absent from the mapping.
    """
    user_group_usage = db_session.execute(
        select(
            func.sum(ChatMessage.token_count),
            func.date_trunc("minute", ChatMessage.time_sent),
            UserGroup.id,
        )
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .join(User__UserGroup, User__UserGroup.user_id == ChatSession.user_id)
        .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
        .filter(UserGroup.id.in_(user_group_ids), ChatMessage.time_sent >= cutoff_time)
        .group_by(func.date_trunc("minute", ChatMessage.time_sent), UserGroup.id)
    ).all()

    # The query has no ORDER BY, so rows of one group are not guaranteed to be
    # adjacent. itertools.groupby only merges *consecutive* equal keys, which
    # would split a group into several runs and let a later run overwrite an
    # earlier one in the result dict. Accumulating per key avoids that.
    usage_by_group: dict[int, list[Tuple[datetime, int]]] = defaultdict(list)
    # Row layout follows the select() above: (token sum, minute, group id).
    for token_sum, minute_bucket, user_group_id in user_group_usage:
        usage_by_group[user_group_id].append((minute_bucket, token_sum))

    return dict(usage_by_group)
|
105
backend/ee/danswer/server/token_rate_limits/api.py
Normal file
105
backend/ee/danswer/server/token_rate_limits/api.py
Normal file
@ -0,0 +1,105 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits
|
||||
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
from ee.danswer.db.token_limit import insert_user_group_token_rate_limit
|
||||
from ee.danswer.db.token_limit import insert_user_token_rate_limit
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
|
||||
|
||||
"""
|
||||
Group Token Limit Settings
|
||||
"""
|
||||
|
||||
|
||||
@router.get("/user-groups")
def get_all_group_token_limit_settings(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, list[TokenRateLimitDisplay]]:
    """
    Return every user-group token rate limit, keyed by group name.

    Each (rate limit, group name) pair from the database layer is converted
    to its display model and appended to the list for that group name.
    """
    limit_and_group_pairs = fetch_all_user_group_token_rate_limits_by_group(
        db_session
    )

    displays_by_group: dict[str, list[TokenRateLimitDisplay]] = {}
    for rate_limit, group_name in limit_and_group_pairs:
        display = TokenRateLimitDisplay.from_db(rate_limit)
        displays_by_group.setdefault(group_name, []).append(display)

    return displays_by_group
|
||||
|
||||
|
||||
@router.get("/user-group/{group_id}")
def get_group_token_limit_settings(
    group_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
    """Return the token rate limits of the user group `group_id` as display models."""
    group_rate_limits = fetch_all_user_group_token_rate_limits(db_session, group_id)
    return [
        TokenRateLimitDisplay.from_db(rate_limit) for rate_limit in group_rate_limits
    ]
|
||||
|
||||
|
||||
@router.post("/user-group/{group_id}")
def create_group_token_limit_settings(
    group_id: int,
    token_limit_settings: TokenRateLimitArgs,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
    """
    Create a token rate limit for the user group `group_id` and return it as
    a display model.
    """
    created_rate_limit = insert_user_group_token_rate_limit(
        db_session=db_session,
        token_rate_limit_settings=token_limit_settings,
        group_id=group_id,
    )
    created_display = TokenRateLimitDisplay.from_db(created_rate_limit)

    # This may be the very first rate limit, so the cached answer of
    # any_rate_limit_exists must be dropped.
    any_rate_limit_exists.cache_clear()
    return created_display
|
||||
|
||||
|
||||
"""
|
||||
User Token Limit Settings
|
||||
"""
|
||||
|
||||
|
||||
@router.get("/users")
def get_user_token_limit_settings(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
    """Return all user-scoped token rate limits as display models."""
    user_rate_limits = fetch_all_user_token_rate_limits(db_session)
    return [
        TokenRateLimitDisplay.from_db(rate_limit) for rate_limit in user_rate_limits
    ]
|
||||
|
||||
|
||||
@router.post("/users")
def create_user_token_limit_settings(
    token_limit_settings: TokenRateLimitArgs,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TokenRateLimitDisplay:
    """Create a user-scoped token rate limit and return it as a display model."""
    created_rate_limit = insert_user_token_rate_limit(db_session, token_limit_settings)
    created_display = TokenRateLimitDisplay.from_db(created_rate_limit)

    # This may be the very first rate limit, so the cached answer of
    # any_rate_limit_exists must be dropped.
    any_rate_limit_exists.cache_clear()
    return created_display
|
@ -48,6 +48,11 @@ const nextConfig = {
|
||||
source: "/admin/performance/custom-analytics",
|
||||
destination: "/ee/admin/performance/custom-analytics",
|
||||
},
|
||||
// token rate limits
|
||||
{
|
||||
source: "/admin/token-rate-limits",
|
||||
destination: "/ee/admin/token-rate-limits",
|
||||
},
|
||||
]
|
||||
: [];
|
||||
|
||||
|
@ -1,176 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import { Form, Formik } from "formik";
|
||||
import { useEffect, useState } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import {
|
||||
BooleanFormField,
|
||||
SectionHeader,
|
||||
TextFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { Popup } from "@/components/admin/connectors/Popup";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import { FiCpu } from "react-icons/fi";
|
||||
import { LLMConfiguration } from "./LLMConfiguration";
|
||||
|
||||
/**
 * Form for the global token budget settings.
 *
 * The current settings are loaded from
 * `/api/manage/admin/token-budget-settings` on mount. Until that request
 * succeeds the component renders nothing.
 */
const LLMOptions = () => {
  const [popup, setPopup] = useState<{
    message: string;
    type: "success" | "error";
  } | null>(null);

  const [tokenBudgetGloballyEnabled, setTokenBudgetGloballyEnabled] =
    useState(false);
  const [initialValues, setInitialValues] = useState({
    enable_token_budget: false,
    token_budget: "",
    token_budget_time_period: "",
  });

  const fetchConfig = async () => {
    const response = await fetch("/api/manage/admin/token-budget-settings");
    if (response.ok) {
      const config = await response.json();
      // Assuming the config object directly matches the structure needed for initialValues
      setInitialValues({
        enable_token_budget: config.enable_token_budget || false,
        token_budget: config.token_budget || "",
        token_budget_time_period: config.token_budget_time_period || "",
      });
      setTokenBudgetGloballyEnabled(true);
    } else {
      // Handle error or provide fallback values
      setPopup({
        message: "Failed to load current LLM options.",
        type: "error",
      });
    }
  };

  // Fetch current config when the component mounts
  useEffect(() => {
    fetchConfig();
  }, []);

  if (!tokenBudgetGloballyEnabled) {
    return null;
  }

  return (
    <>
      {popup && <Popup message={popup.message} type={popup.type} />}
      <Formik
        enableReinitialize={true}
        initialValues={initialValues}
        onSubmit={async (values) => {
          const response = await fetch(
            "/api/manage/admin/token-budget-settings",
            {
              method: "PUT",
              headers: {
                "Content-Type": "application/json",
              },
              body: JSON.stringify(values),
            }
          );
          if (response.ok) {
            setPopup({
              message: "Updated LLM Options",
              type: "success",
            });
            await fetchConfig();
          } else {
            const body = await response.json();
            if (body.detail) {
              setPopup({ message: body.detail, type: "error" });
            } else {
              setPopup({
                message: "Unable to update LLM options.",
                type: "error",
              });
            }
          }
          // Dismiss the popup after either outcome. Previously the timer was
          // only started on failure, so the success popup stayed forever.
          setTimeout(() => {
            setPopup(null);
          }, 4000);
        }}
      >
        {({ isSubmitting, values, setFieldValue }) => {
          // Allow only integer values: ignore any edit that would leave
          // something other than digits (or an empty string) in the field.
          const setIntegerField = (field: string, value: string) => {
            if (value === "" || /^[0-9]+$/.test(value)) {
              setFieldValue(field, value);
            }
          };

          return (
            <Form>
              <Divider />
              <>
                <SectionHeader>Token Budget</SectionHeader>
                <Text>
                  Set a maximum token use per time period. If the token budget
                  is exceeded, Danswer will not be able to respond to queries
                  until the next time period.
                </Text>
                <br />
                <BooleanFormField
                  name="enable_token_budget"
                  label="Enable Token Budget"
                  subtext="If enabled, Danswer will be limited to the token budget specified below."
                  onChange={(e) => {
                    setFieldValue("enable_token_budget", e.target.checked);
                  }}
                />
                {values.enable_token_budget && (
                  <>
                    <TextFormField
                      name="token_budget"
                      label="Token Budget"
                      subtext={
                        <div>
                          How many tokens (in thousands) can be used per time
                          period? If unspecified, no limit will be set.
                        </div>
                      }
                      onChange={(e) => {
                        setIntegerField("token_budget", e.target.value);
                      }}
                    />
                    <TextFormField
                      name="token_budget_time_period"
                      label="Token Budget Time Period (hours)"
                      subtext={
                        <div>
                          Specify the length of the time period, in hours, over
                          which the token budget will be applied.
                        </div>
                      }
                      onChange={(e) => {
                        setIntegerField(
                          "token_budget_time_period",
                          e.target.value
                        );
                      }}
                    />
                  </>
                )}
              </>
              <div className="flex">
                <Button
                  className="w-64 mx-auto"
                  type="submit"
                  disabled={isSubmitting}
                >
                  Submit
                </Button>
              </div>
            </Form>
          );
        }}
      </Formik>
    </>
  );
};
|
||||
|
||||
const Page = () => {
|
||||
return (
|
||||
<div className="mx-auto container">
|
||||
@ -181,7 +14,6 @@ const Page = () => {
|
||||
|
||||
<LLMConfiguration />
|
||||
|
||||
<LLMOptions />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
175
web/src/app/admin/token-rate-limits/CreateRateLimitModal.tsx
Normal file
175
web/src/app/admin/token-rate-limits/CreateRateLimitModal.tsx
Normal file
@ -0,0 +1,175 @@
|
||||
"use client";
|
||||
|
||||
import * as Yup from "yup";
|
||||
import { Button } from "@tremor/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Form, Formik } from "formik";
|
||||
import {
|
||||
SelectorFormField,
|
||||
TextFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { UserGroup } from "@/lib/types";
|
||||
import { Scope } from "./types";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
|
||||
interface CreateRateLimitModalProps {
  // Whether the modal is currently shown.
  isOpen: boolean;
  setIsOpen: (isOpen: boolean) => void;
  // Called with the validated form values when the form is submitted.
  onSubmit: (
    target_scope: Scope,
    period_hours: number,
    token_budget: number,
    group_id: number
  ) => void;
  setPopup: (popupSpec: PopupSpec | null) => void;
  // When set, the scope selector is hidden and this scope is used.
  forSpecificScope?: Scope;
  // When set, the user group selector is hidden and this group is used.
  forSpecificUserGroup?: number;
}

/**
 * Modal with a form for creating a token rate limit.
 *
 * The list of user groups is only fetched once the "User Group" scope is
 * in play: either preselected through `forSpecificScope` or chosen in the
 * scope selector.
 */
export const CreateRateLimitModal = ({
  isOpen,
  setIsOpen,
  onSubmit,
  setPopup,
  forSpecificScope,
  forSpecificUserGroup,
}: CreateRateLimitModalProps) => {
  const [modalUserGroups, setModalUserGroups] = useState<
    { name: string; value: number }[]
  >([]);
  const [shouldFetchUserGroups, setShouldFetchUserGroups] = useState(
    forSpecificScope === Scope.USER_GROUP
  );

  useEffect(() => {
    if (!shouldFetchUserGroups) {
      return;
    }

    const fetchData = async () => {
      try {
        const response = await fetch("/api/manage/admin/user-group");
        // fetch() resolves on HTTP errors too, so the status must be
        // checked before the body is treated as a list of groups.
        if (!response.ok) {
          throw new Error(`${response.status} ${response.statusText}`);
        }
        const data: UserGroup[] = await response.json();
        const options = data.map((userGroup: UserGroup) => ({
          name: userGroup.name,
          value: userGroup.id,
        }));
        setModalUserGroups(options);
      } catch (error) {
        setPopup({
          type: "error",
          message: `Failed to fetch user groups: ${error}`,
        });
      } finally {
        // Reset the flag on failure as well. Otherwise the effect could run
        // again (e.g. when `setPopup` changes identity) and refetch in a loop.
        setShouldFetchUserGroups(false);
      }
    };

    fetchData();
  }, [shouldFetchUserGroups, setPopup]);

  if (!isOpen) {
    return null;
  }

  return (
    <Modal
      title={"Create a Token Rate Limit"}
      onOutsideClick={() => setIsOpen(false)}
      width="w-2/6"
    >
      <Formik
        initialValues={{
          enabled: true,
          period_hours: "",
          token_budget: "",
          target_scope: forSpecificScope || Scope.GLOBAL,
          user_group_id: forSpecificUserGroup,
        }}
        validationSchema={Yup.object().shape({
          period_hours: Yup.number()
            .required("Time Window is a required field")
            .min(1, "Time Window must be at least 1 hour"),
          token_budget: Yup.number()
            .required("Token Budget is a required field")
            .min(1, "Token Budget must be at least 1"),
          target_scope: Yup.string().required(
            "Target Scope is a required field"
          ),
          // A group is only required when the limit targets a user group.
          user_group_id: Yup.string().test(
            "user_group_id",
            "User Group is a required field",
            (value, context) =>
              context.parent.target_scope !== Scope.USER_GROUP ||
              value !== undefined
          ),
        })}
        onSubmit={async (values, formikHelpers) => {
          formikHelpers.setSubmitting(true);
          onSubmit(
            values.target_scope,
            Number(values.period_hours),
            Number(values.token_budget),
            Number(values.user_group_id)
          );
          return formikHelpers.setSubmitting(false);
        }}
      >
        {({ isSubmitting, values, setFieldValue }) => (
          <Form>
            {!forSpecificScope && (
              <SelectorFormField
                name="target_scope"
                label="Target Scope"
                options={[
                  { name: "Global", value: Scope.GLOBAL },
                  { name: "User", value: Scope.USER },
                  { name: "User Group", value: Scope.USER_GROUP },
                ]}
                includeDefault={false}
                onSelect={(selected) => {
                  setFieldValue("target_scope", selected);
                  if (selected === Scope.USER_GROUP) {
                    setShouldFetchUserGroups(true);
                  }
                }}
              />
            )}
            {forSpecificUserGroup === undefined &&
              values.target_scope === Scope.USER_GROUP && (
                <SelectorFormField
                  name="user_group_id"
                  label="User Group"
                  options={modalUserGroups}
                  includeDefault={false}
                />
              )}
            <TextFormField
              name="period_hours"
              label="Time Window (Hours)"
              type="number"
              placeholder=""
            />
            <TextFormField
              name="token_budget"
              label="Token Budget (Thousands)"
              type="number"
              placeholder=""
            />
            <div className="flex">
              <Button
                type="submit"
                size="xs"
                color="green"
                disabled={isSubmitting}
                className="mx-auto w-64"
              >
                Create!
              </Button>
            </div>
          </Form>
        )}
      </Formik>
    </Modal>
  );
};
|
169
web/src/app/admin/token-rate-limits/TokenRateLimitTables.tsx
Normal file
169
web/src/app/admin/token-rate-limits/TokenRateLimitTables.tsx
Normal file
@ -0,0 +1,169 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableHead,
|
||||
TableRow,
|
||||
TableHeaderCell,
|
||||
TableBody,
|
||||
TableCell,
|
||||
Title,
|
||||
Text,
|
||||
} from "@tremor/react";
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import { deleteTokenRateLimit, updateTokenRateLimit } from "./lib";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { TokenRateLimitDisplay } from "./types";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { CustomCheckbox } from "@/components/CustomCheckbox";
|
||||
|
||||
type TokenRateLimitTableArgs = {
  tokenRateLimits: TokenRateLimitDisplay[];
  title?: string;
  description?: string;
  // SWR key that is revalidated after a rate limit is toggled or deleted.
  fetchUrl: string;
  // When true, neither the title nor the description is rendered.
  hideHeading?: boolean;
};

/**
 * Table of token rate limits with an enable/disable toggle and a delete
 * button per row. A "Group Name" column is shown when the first entry
 * carries a `group_name`.
 */
export const TokenRateLimitTable = ({
  tokenRateLimits,
  title,
  description,
  fetchUrl,
  hideHeading,
}: TokenRateLimitTableArgs) => {
  const shouldRenderGroupName =
    tokenRateLimits.length > 0 && tokenRateLimits[0].group_name !== undefined;

  // Spacing below the heading. An explicit ternary is used because
  // `${!hideHeading && "my-8"}` renders the literal class name "false".
  const bodySpacing = hideHeading ? "" : "my-8";

  const handleEnabledChange = (id: number) => {
    const tokenRateLimit = tokenRateLimits.find(
      (tokenRateLimit) => tokenRateLimit.token_id === id
    );

    if (!tokenRateLimit) {
      return;
    }

    // Resend budget and period unchanged; only `enabled` is flipped.
    updateTokenRateLimit(id, {
      token_budget: tokenRateLimit.token_budget,
      period_hours: tokenRateLimit.period_hours,
      enabled: !tokenRateLimit.enabled,
    }).then(() => {
      mutate(fetchUrl);
    });
  };

  const handleDelete = (id: number) =>
    deleteTokenRateLimit(id).then(() => {
      mutate(fetchUrl);
    });

  // Shared by the empty state and the populated table.
  const heading = !hideHeading && (
    <>
      {title && <Title>{title}</Title>}
      {description && <Text className="my-2">{description}</Text>}
    </>
  );

  if (tokenRateLimits.length === 0) {
    return (
      <div>
        {heading}
        <Text className={bodySpacing}>No token rate limits set!</Text>
      </div>
    );
  }

  return (
    <div>
      {heading}
      <Table
        className={hideHeading ? "overflow-visible" : "overflow-visible my-8"}
      >
        <TableHead>
          <TableRow>
            <TableHeaderCell>Enabled</TableHeaderCell>
            {shouldRenderGroupName && (
              <TableHeaderCell>Group Name</TableHeaderCell>
            )}
            <TableHeaderCell>Time Window (Hours)</TableHeaderCell>
            <TableHeaderCell>Token Budget (Thousands)</TableHeaderCell>
            <TableHeaderCell>Delete</TableHeaderCell>
          </TableRow>
        </TableHead>
        <TableBody>
          {tokenRateLimits.map((tokenRateLimit) => {
            return (
              <TableRow key={tokenRateLimit.token_id}>
                <TableCell>
                  <div
                    onClick={() => handleEnabledChange(tokenRateLimit.token_id)}
                    className="px-1 py-0.5 hover:bg-hover-light rounded flex cursor-pointer select-none w-24"
                  >
                    <div className="mx-auto flex">
                      <CustomCheckbox checked={tokenRateLimit.enabled} />
                      <p className="ml-2">
                        {tokenRateLimit.enabled ? "Enabled" : "Disabled"}
                      </p>
                    </div>
                  </div>
                </TableCell>
                {shouldRenderGroupName && (
                  <TableCell className="font-bold text-emphasis">
                    {tokenRateLimit.group_name}
                  </TableCell>
                )}
                <TableCell>{tokenRateLimit.period_hours}</TableCell>
                <TableCell>{tokenRateLimit.token_budget}</TableCell>
                <TableCell>
                  <DeleteButton
                    onClick={() => handleDelete(tokenRateLimit.token_id)}
                  />
                </TableCell>
              </TableRow>
            );
          })}
        </TableBody>
      </Table>
    </div>
  );
};
|
||||
|
||||
/**
 * Loads token rate limits from `fetchUrl` with SWR and renders them in a
 * `TokenRateLimitTable`. An optional `responseMapper` converts the raw
 * response into the list of rows to display.
 */
export const GenericTokenRateLimitTable = ({
  fetchUrl,
  title,
  description,
  hideHeading,
  responseMapper,
}: {
  fetchUrl: string;
  title?: string;
  description?: string;
  hideHeading?: boolean;
  responseMapper?: (data: any) => TokenRateLimitDisplay[];
}) => {
  const { data, isLoading, error } = useSWR(fetchUrl, errorHandlingFetcher);

  if (isLoading) {
    return <ThreeDotsLoader />;
  }

  // Loading has finished once this point is reached.
  if (error) {
    return <Text>Failed to load token rate limits</Text>;
  }

  const rows = responseMapper ? responseMapper(data) : data;

  return (
    <TokenRateLimitTable
      tokenRateLimits={rows}
      fetchUrl={fetchUrl}
      title={title}
      description={description}
      hideHeading={hideHeading}
    />
  );
};
|
64
web/src/app/admin/token-rate-limits/lib.ts
Normal file
64
web/src/app/admin/token-rate-limits/lib.ts
Normal file
@ -0,0 +1,64 @@
|
||||
import { TokenRateLimitArgs } from "./types";
|
||||
|
||||
const API_PREFIX = "/api/admin/token-rate-limits";
|
||||
|
||||
// Global Token Limits
|
||||
/** POST a new global token rate limit and resolve with the raw response. */
export const insertGlobalTokenRateLimit = async (
  tokenRateLimit: TokenRateLimitArgs
) => {
  const request: RequestInit = {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(tokenRateLimit),
  };
  const response = await fetch(`${API_PREFIX}/global`, request);
  return response;
};
|
||||
|
||||
// User Token Limits
|
||||
/** POST a new user-scoped token rate limit and resolve with the raw response. */
export const insertUserTokenRateLimit = async (
  tokenRateLimit: TokenRateLimitArgs
) => {
  const request: RequestInit = {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(tokenRateLimit),
  };
  const response = await fetch(`${API_PREFIX}/users`, request);
  return response;
};
|
||||
|
||||
// User Group Token Limits (EE Only)
|
||||
/**
 * POST a new token rate limit for the user group `group_id` and resolve
 * with the raw response.
 */
export const insertGroupTokenRateLimit = async (
  tokenRateLimit: TokenRateLimitArgs,
  group_id: number
) => {
  const request: RequestInit = {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(tokenRateLimit),
  };
  const response = await fetch(`${API_PREFIX}/user-group/${group_id}`, request);
  return response;
};
|
||||
|
||||
// Common Endpoints
|
||||
|
||||
/** DELETE the token rate limit with the given id and resolve with the raw response. */
export const deleteTokenRateLimit = async (token_rate_limit_id: number) => {
  const url = `${API_PREFIX}/rate-limit/${token_rate_limit_id}`;
  const response = await fetch(url, { method: "DELETE" });
  return response;
};
|
||||
|
||||
/**
 * PUT new settings for the token rate limit with the given id and resolve
 * with the raw response.
 */
export const updateTokenRateLimit = async (
  token_rate_limit_id: number,
  tokenRateLimit: TokenRateLimitArgs
) => {
  const url = `${API_PREFIX}/rate-limit/${token_rate_limit_id}`;
  const request: RequestInit = {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(tokenRateLimit),
  };
  const response = await fetch(url, request);
  return response;
};
|
223
web/src/app/admin/token-rate-limits/page.tsx
Normal file
223
web/src/app/admin/token-rate-limits/page.tsx
Normal file
@ -0,0 +1,223 @@
|
||||
"use client";
|
||||
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import {
|
||||
Button,
|
||||
Tab,
|
||||
TabGroup,
|
||||
TabList,
|
||||
TabPanel,
|
||||
TabPanels,
|
||||
Text,
|
||||
} from "@tremor/react";
|
||||
import { useState } from "react";
|
||||
import { FiGlobe, FiShield, FiUser, FiUsers } from "react-icons/fi";
|
||||
import {
|
||||
insertGlobalTokenRateLimit,
|
||||
insertGroupTokenRateLimit,
|
||||
insertUserTokenRateLimit,
|
||||
} from "./lib";
|
||||
import { Scope, TokenRateLimit } from "./types";
|
||||
import { GenericTokenRateLimitTable } from "./TokenRateLimitTables";
|
||||
import { mutate } from "swr";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { CreateRateLimitModal } from "./CreateRateLimitModal";
|
||||
import { EE_ENABLED } from "@/lib/constants";
|
||||
|
||||
const BASE_URL = "/api/admin/token-rate-limits";
|
||||
const GLOBAL_TOKEN_FETCH_URL = `${BASE_URL}/global`;
|
||||
const USER_TOKEN_FETCH_URL = `${BASE_URL}/users`;
|
||||
const USER_GROUP_FETCH_URL = `${BASE_URL}/user-groups`;
|
||||
|
||||
const GLOBAL_DESCRIPTION =
|
||||
"Global rate limits apply to all users, user groups, and API keys. When the global \
|
||||
rate limit is reached, no more tokens can be spent.";
|
||||
const USER_DESCRIPTION =
|
||||
"User rate limits apply to individual users. When a user reaches a limit, they will \
|
||||
be temporarily blocked from spending tokens.";
|
||||
const USER_GROUP_DESCRIPTION =
|
||||
"User group rate limits apply to all users in a group. When a group reaches a limit, \
|
||||
all users in the group will be temporarily blocked from spending tokens, regardless \
|
||||
of their individual limits. If a user is in multiple groups, the most lenient limit \
|
||||
will apply.";
|
||||
|
||||
/**
 * Create a token rate limit for the given scope.
 *
 * Resolves with the response when the server accepted the request.
 * Throws an Error when `target_scope` is not a known scope, or when the
 * server answers with a non-2xx status; the message is the response's
 * `detail` when that is a string.
 */
const handleCreateTokenRateLimit = async (
  target_scope: Scope,
  period_hours: number,
  token_budget: number,
  group_id: number = -1
) => {
  const tokenRateLimitArgs = {
    enabled: true,
    token_budget: token_budget,
    period_hours: period_hours,
  };

  let response: Response;
  if (target_scope === Scope.GLOBAL) {
    response = await insertGlobalTokenRateLimit(tokenRateLimitArgs);
  } else if (target_scope === Scope.USER) {
    response = await insertUserTokenRateLimit(tokenRateLimitArgs);
  } else if (target_scope === Scope.USER_GROUP) {
    response = await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id);
  } else {
    throw new Error(`Invalid target_scope: ${target_scope}`);
  }

  // fetch() resolves on HTTP error statuses too. Without this check the
  // caller's success handler would run even though nothing was created.
  if (!response.ok) {
    const body = await response.json().catch(() => null);
    const detail = typeof body?.detail === "string" ? body.detail : null;
    throw new Error(
      detail ?? `Failed to create token rate limit (HTTP ${response.status})`
    );
  }

  return response;
};
|
||||
|
||||
/**
 * Body of the token rate limit admin page: an introduction, the "create"
 * button and modal, and the rate limit tables. With `EE_ENABLED` the
 * global, user and user-group tables are shown in tabs; otherwise only the
 * global table is shown and the modal is fixed to the global scope.
 */
function Main() {
  const [tabIndex, setTabIndex] = useState(0);
  const [modalIsOpen, setModalIsOpen] = useState(false);
  const { popup, setPopup } = usePopup();

  // Revalidate the table for the given scope and switch to its tab.
  const updateTable = (target_scope: Scope) => {
    if (target_scope === Scope.GLOBAL) {
      mutate(GLOBAL_TOKEN_FETCH_URL);
      setTabIndex(0);
    } else if (target_scope === Scope.USER) {
      mutate(USER_TOKEN_FETCH_URL);
      setTabIndex(1);
    } else if (target_scope === Scope.USER_GROUP) {
      mutate(USER_GROUP_FETCH_URL);
      setTabIndex(2);
    }
  };

  const handleSubmit = (
    target_scope: Scope,
    period_hours: number,
    token_budget: number,
    group_id: number = -1
  ) => {
    handleCreateTokenRateLimit(
      target_scope,
      period_hours,
      token_budget,
      group_id
    )
      .then(() => {
        setModalIsOpen(false);
        setPopup({ type: "success", message: "Token rate limit created!" });
        updateTable(target_scope);
      })
      .catch((error) => {
        setPopup({ type: "error", message: error.message });
      });
  };

  return (
    <div>
      {popup}

      <Text className="mb-2">
        Token rate limits enable you to control how many tokens can be spent in
        a given time period. With token rate limits, you can:
      </Text>

      <ul className="list-disc mt-2 ml-4 mb-2">
        <li>
          <Text>
            Set a global rate limit to control your organization's overall
            token spend.
          </Text>
        </li>
        {EE_ENABLED && (
          <>
            <li>
              <Text>
                Set rate limits for users to ensure that no single user can
                spend too many tokens.
              </Text>
            </li>
            <li>
              <Text>
                Set rate limits for user groups to control token spend for your
                teams.
              </Text>
            </li>
          </>
        )}
        <li>
          <Text>Enable and disable rate limits on the fly.</Text>
        </li>
      </ul>

      <Button
        color="green"
        size="xs"
        className="mt-3"
        onClick={() => setModalIsOpen(true)}
      >
        Create a Token Rate Limit
      </Button>

      {EE_ENABLED && (
        <TabGroup className="mt-6" index={tabIndex} onIndexChange={setTabIndex}>
          <TabList variant="line">
            <Tab icon={FiGlobe}>Global</Tab>
            <Tab icon={FiUser}>User</Tab>
            <Tab icon={FiUsers}>User Groups</Tab>
          </TabList>
          <TabPanels className="mt-6">
            <TabPanel>
              <GenericTokenRateLimitTable
                fetchUrl={GLOBAL_TOKEN_FETCH_URL}
                title={"Global Token Rate Limits"}
                description={GLOBAL_DESCRIPTION}
              />
            </TabPanel>
            <TabPanel>
              <GenericTokenRateLimitTable
                fetchUrl={USER_TOKEN_FETCH_URL}
                title={"User Token Rate Limits"}
                description={USER_DESCRIPTION}
              />
            </TabPanel>
            <TabPanel>
              <GenericTokenRateLimitTable
                fetchUrl={USER_GROUP_FETCH_URL}
                title={"User Group Token Rate Limits"}
                description={USER_GROUP_DESCRIPTION}
                // The endpoint returns limits keyed by group name; flatten
                // them into rows that each carry their group name.
                responseMapper={(data: Record<string, TokenRateLimit[]>) =>
                  Object.entries(data).flatMap(([group_name, elements]) =>
                    elements.map((element) => ({
                      ...element,
                      group_name,
                    }))
                  )
                }
              />
            </TabPanel>
          </TabPanels>
        </TabGroup>
      )}

      {!EE_ENABLED && (
        <div className="mt-6">
          <GenericTokenRateLimitTable
            fetchUrl={GLOBAL_TOKEN_FETCH_URL}
            title={"Global Token Rate Limits"}
            description={GLOBAL_DESCRIPTION}
          />
        </div>
      )}

      <CreateRateLimitModal
        isOpen={modalIsOpen}
        setIsOpen={setModalIsOpen}
        setPopup={setPopup}
        onSubmit={handleSubmit}
        forSpecificScope={EE_ENABLED ? undefined : Scope.GLOBAL}
      />
    </div>
  );
}
|
||||
|
||||
/** Token rate limit admin page: the page title followed by the main content. */
export default function Page() {
  const titleIcon = <FiShield size={32} />;

  return (
    <div className="mx-auto container">
      <AdminPageTitle title="Token Rate Limits" icon={titleIcon} />

      <Main />
    </div>
  );
}
|
22
web/src/app/admin/token-rate-limits/types.ts
Normal file
22
web/src/app/admin/token-rate-limits/types.ts
Normal file
@ -0,0 +1,22 @@
|
||||
/** What a token rate limit applies to. */
export enum Scope {
  /** Limit scoped to users. */
  USER = "user",
  /** Limit scoped to a user group. */
  USER_GROUP = "user_group",
  /** Limit scoped globally. */
  GLOBAL = "global",
}
|
||||
|
||||
/** Settings sent to the API when creating or updating a token rate limit. */
export interface TokenRateLimitArgs {
  /** Whether the limit is active. */
  enabled: boolean;
  /** Token budget; the admin UI labels this value as thousands of tokens. */
  token_budget: number;
  /** Length of the time window, in hours. */
  period_hours: number;
}
|
||||
|
||||
/**
 * A stored token rate limit: the settings from `TokenRateLimitArgs`
 * (`enabled`, `token_budget`, `period_hours`) plus its identifier.
 */
export interface TokenRateLimit extends TokenRateLimitArgs {
  /** Identifier of the rate limit. */
  token_id: number;
}
|
||||
|
||||
/** A token rate limit as shown in the admin tables. */
export interface TokenRateLimitDisplay extends TokenRateLimit {
  /** Name of the owning user group; only set for user-group limits. */
  group_name?: string;
}
|
@ -0,0 +1,60 @@
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { CreateRateLimitModal } from "../../../../admin/token-rate-limits/CreateRateLimitModal";
|
||||
import { Scope } from "../../../../admin/token-rate-limits/types";
|
||||
import { insertGroupTokenRateLimit } from "../../../../admin/token-rate-limits/lib";
|
||||
import { mutate } from "swr";
|
||||
|
||||
/**
 * Props accepted by `AddTokenRateLimitForm`.
 *
 * NOTE(review): the name `AddMemberFormProps` looks like a holdover from a
 * member form; the interface is only used by the token rate limit form here.
 */
interface AddMemberFormProps {
|
||||
// Whether the creation modal is currently shown.
isOpen: boolean;
|
||||
// Setter for the modal's visibility; called with `false` after a successful create.
setIsOpen: (isOpen: boolean) => void;
|
||||
// Shows a popup (or clears it when given `null`).
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
// Id of the user group the form is bound to.
userGroupId: number;
|
||||
}
|
||||
|
||||
/**
 * Creates an enabled token rate limit for a user group.
 *
 * Builds the limit arguments (always `enabled: true`) from the given period
 * and budget and forwards them, together with `group_id`, to
 * `insertGroupTokenRateLimit`.
 *
 * @param period_hours Value stored as the limit's `period_hours`.
 * @param token_budget Value stored as the limit's `token_budget`.
 * @param group_id Target group id; defaults to -1 when omitted.
 * @returns Whatever `insertGroupTokenRateLimit` resolves to.
 */
const handleCreateGroupTokenRateLimit = async (
|
||||
period_hours: number,
|
||||
token_budget: number,
|
||||
group_id: number = -1
|
||||
) => {
|
||||
const tokenRateLimitArgs = {
|
||||
enabled: true,
|
||||
token_budget: token_budget,
|
||||
period_hours: period_hours,
|
||||
};
|
||||
return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id);
|
||||
};
|
||||
|
||||
/**
 * Modal form for creating a token rate limit on a single user group.
 *
 * Wraps `CreateRateLimitModal`, fixing its scope to `Scope.USER_GROUP` and its
 * user group to `userGroupId`. On submit it creates the limit; on success it
 * closes the modal, shows a success popup and calls SWR's `mutate` for the
 * group's rate-limit URL; on failure it shows the error message in a popup.
 */
export const AddTokenRateLimitForm: React.FC<AddMemberFormProps> = ({
|
||||
isOpen,
|
||||
setIsOpen,
|
||||
setPopup,
|
||||
userGroupId,
|
||||
}) => {
|
||||
const handleSubmit = (
|
||||
// The scope argument is ignored; this form only ever creates group limits.
_: Scope,
|
||||
period_hours: number,
|
||||
token_budget: number,
|
||||
group_id: number = -1
|
||||
) => {
|
||||
handleCreateGroupTokenRateLimit(period_hours, token_budget, group_id)
|
||||
.then(() => {
|
||||
setIsOpen(false);
|
||||
setPopup({ type: "success", message: "Token rate limit created!" });
|
||||
// Same URL that GroupDisplay passes as `fetchUrl` to the rate limit table;
// presumably this triggers a refetch of that table — confirm it uses SWR with this key.
mutate(`/api/admin/token-rate-limits/user-group/${userGroupId}`);
|
||||
})
|
||||
.catch((error) => {
|
||||
setPopup({ type: "error", message: error.message });
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<CreateRateLimitModal
|
||||
isOpen={isOpen}
|
||||
setIsOpen={setIsOpen}
|
||||
onSubmit={handleSubmit}
|
||||
setPopup={setPopup}
|
||||
forSpecificScope={Scope.USER_GROUP}
|
||||
forSpecificUserGroup={userGroupId}
|
||||
/>
|
||||
);
|
||||
};
|
@ -22,6 +22,8 @@ import {
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { BookmarkIcon, RobotIcon } from "@/components/icons/icons";
|
||||
import { AddTokenRateLimitForm } from "./AddTokenRateLimitForm";
|
||||
import { GenericTokenRateLimitTable } from "@/app/admin/token-rate-limits/TokenRateLimitTables";
|
||||
|
||||
interface GroupDisplayProps {
|
||||
users: User[];
|
||||
@ -39,6 +41,7 @@ export const GroupDisplay = ({
|
||||
const { popup, setPopup } = usePopup();
|
||||
const [addMemberFormVisible, setAddMemberFormVisible] = useState(false);
|
||||
const [addConnectorFormVisible, setAddConnectorFormVisible] = useState(false);
|
||||
const [addRateLimitFormVisible, setAddRateLimitFormVisible] = useState(false);
|
||||
|
||||
return (
|
||||
<div>
|
||||
@ -301,6 +304,31 @@ export const GroupDisplay = ({
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Divider />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Token Rate Limits</h2>
|
||||
|
||||
<AddTokenRateLimitForm
|
||||
isOpen={addRateLimitFormVisible}
|
||||
setIsOpen={setAddRateLimitFormVisible}
|
||||
setPopup={setPopup}
|
||||
userGroupId={userGroup.id}
|
||||
/>
|
||||
|
||||
<GenericTokenRateLimitTable
|
||||
fetchUrl={`/api/admin/token-rate-limits/user-group/${userGroup.id}`}
|
||||
hideHeading
|
||||
/>
|
||||
|
||||
<Button
|
||||
color="green"
|
||||
size="xs"
|
||||
className="mt-3"
|
||||
onClick={() => setAddRateLimitFormVisible(true)}
|
||||
>
|
||||
Create a Token Rate Limit
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
@ -28,6 +28,7 @@ import {
|
||||
FiImage,
|
||||
FiPackage,
|
||||
FiSettings,
|
||||
FiShield,
|
||||
FiSlack,
|
||||
FiTool,
|
||||
} from "react-icons/fi";
|
||||
@ -215,6 +216,15 @@ export async function Layout({ children }: { children: React.ReactNode }) {
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<FiShield size={18} />
|
||||
<div className="ml-1">Token Rate Limits</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/token-rate-limits",
|
||||
},
|
||||
],
|
||||
},
|
||||
...(EE_ENABLED
|
||||
|
Loading…
x
Reference in New Issue
Block a user