rkuo-danswer 24184024bb
Bugfix/dependency updates (#4482)
* bump fastapi and starlette

* bumping llama index and nltk and associated deps

* bump to fix python-multipart

* bump aiohttp

* update package lock for examples/widget

* bump black

* SentenceSplitter has changed namespaces (see the import sketch after this list)

* fix reorder import check, fix missing passlib

* update package-lock.json

* black formatter updated

* reformatted again

* change to black compatible reorder

* change to black compatible reorder-python-imports fork

* fix pytest dependency

* black format again

* we don't need cdk.txt. update packages to be consistent across all packages
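
For reference, a minimal sketch of the kind of namespace change involved
(assuming the llama-index 0.9 -> 0.10 package reorganization; the exact old
import path used in this repo may differ):

    # llama-index < 0.10 (illustrative)
    from llama_index.text_splitter import SentenceSplitter
    # llama-index >= 0.10
    from llama_index.core.node_parser import SentenceSplitter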

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-04-10 08:23:02 +00:00


from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import User
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/llm")
basic_router = APIRouter(prefix="/llm")
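# Endpoints registered on admin_router require an admin user; endpoints on
# basic_router are available to any user with chat access.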
@admin_router.get("/built-in/options")
def fetch_llm_options(
_: User | None = Depends(current_admin_user),
) -> list[WellKnownLLMProviderDescriptor]:
return fetch_available_well_known_llms()
@admin_router.post("/test")
def test_llm_configuration(
test_llm_request: TestLLMRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
"""Test regular llm and fast llm settings"""
# the api key is sanitized if we are testing a provider already in the system
test_api_key = test_llm_request.api_key
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
test_llm_request.name, db_session
)
if existing_provider:
test_api_key = existing_provider.api_key
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
deployment_name=test_llm_request.deployment_name,
)
functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
if (
test_llm_request.fast_default_model_name
and test_llm_request.fast_default_model_name
!= test_llm_request.default_model_name
):
fast_llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
deployment_name=test_llm_request.deployment_name,
)
functions_with_args.append((test_llm, (fast_llm,)))
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error = parallel_results[0] or (
parallel_results[1] if len(parallel_results) > 1 else None
)
if error:
client_error_msg = litellm_exception_to_error_msg(
error, llm, fallback_to_error_msg=True
)
raise HTTPException(status_code=400, detail=client_error_msg)
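# Run connectivity checks against the currently configured default and fast
# default LLMs; the first failure is surfaced to the client as a 400.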
@admin_router.post("/test/default")
def test_default_provider(
_: User | None = Depends(current_admin_user),
) -> None:
try:
llm, fast_llm = get_default_llms()
except ValueError:
logger.exception("Failed to fetch default LLM Provider")
raise HTTPException(status_code=400, detail="No LLM Provider setup")
functions_with_args: list[tuple[Callable, tuple]] = [
(test_llm, (llm,)),
(test_llm, (fast_llm,)),
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error = parallel_results[0] or (
parallel_results[1] if len(parallel_results) > 1 else None
)
if error:
raise HTTPException(status_code=400, detail=error)
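# Admin-only listing of full provider configurations; API keys are masked to
# their first and last four characters before being returned.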
@admin_router.get("/provider")
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(db_session):
from_model_start = datetime.now(timezone.utc)
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
from_model_end = datetime.now(timezone.utc)
from_model_duration = (from_model_end - from_model_start).total_seconds()
logger.debug(
f"LLMProviderView.from_model took {from_model_duration:.2f} seconds"
)
if full_llm_provider.api_key:
full_llm_provider.api_key = (
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
)
llm_provider_list.append(full_llm_provider)
end_time = datetime.now(timezone.utc)
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
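# Create or update a provider configuration. With is_creation=True, reusing an
# existing provider name is rejected; on update, the stored API key is kept
# unless the client explicitly changed it.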
@admin_router.put("/provider")
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
False,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LLMProviderView:
    # validate request (e.g. if we're intending to create but the name already exists we should throw an error)
    # NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
    # the result
    existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
    if existing_provider and is_creation:
        raise HTTPException(
            status_code=400,
            detail=f"LLM Provider with name {llm_provider.name} already exists",
        )

    if llm_provider.display_model_names is not None:
        # Ensure default_model_name and fast_default_model_name are in display_model_names
        # This is necessary for custom models and Bedrock/Azure models
        if llm_provider.default_model_name not in llm_provider.display_model_names:
            llm_provider.display_model_names.append(llm_provider.default_model_name)

        if (
            llm_provider.fast_default_model_name
            and llm_provider.fast_default_model_name
            not in llm_provider.display_model_names
        ):
            llm_provider.display_model_names.append(
                llm_provider.fast_default_model_name
            )

    # the llm api key is sanitized when returned to clients, so the only time we
    # should get a real key is when it is explicitly changed
    if existing_provider and not llm_provider.api_key_changed:
        llm_provider.api_key = existing_provider.api_key

    try:
        return upsert_llm_provider(
            llm_provider=llm_provider,
            db_session=db_session,
        )
    except ValueError as e:
        logger.exception("Failed to upsert LLM Provider")
        raise HTTPException(status_code=400, detail=str(e))


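# Delete an LLM provider configuration.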
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
remove_llm_provider(db_session, provider_id)
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
)
@admin_router.get("/vision-providers")
def get_vision_capable_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
providers = fetch_existing_llm_providers(db_session)
vision_providers = []
logger.info("Fetching vision-capable providers")
for provider in providers:
vision_models = []
# Check model names in priority order
model_names_to_check = []
if provider.model_names:
model_names_to_check = provider.model_names
elif provider.display_model_names:
model_names_to_check = provider.display_model_names
elif provider.default_model_name:
model_names_to_check = [provider.default_model_name]
# Check each model for vision capability
for model_name in model_names_to_check:
if model_supports_image_input(model_name, provider.provider):
vision_models.append(model_name)
logger.debug(f"Vision model found: {provider.provider}/{model_name}")
# Only include providers with at least one vision-capable model
if vision_models:
provider_dict = LLMProviderView.from_model(provider).model_dump()
provider_dict["vision_models"] = vision_models
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"
)
vision_providers.append(VisionProviderResponse(**provider_dict))
logger.info(f"Found {len(vision_providers)} vision-capable providers")
return vision_providers
"""Endpoints for all"""
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch basic LLM providers for user")
llm_provider_list: list[LLMProviderDescriptor] = []
for llm_provider_model in fetch_existing_llm_providers_for_user(db_session, user):
from_model_start = datetime.now(timezone.utc)
full_llm_provider = LLMProviderDescriptor.from_model(llm_provider_model)
from_model_end = datetime.now(timezone.utc)
from_model_duration = (from_model_end - from_model_start).total_seconds()
logger.debug(
f"LLMProviderView.from_model took {from_model_duration:.2f} seconds"
)
llm_provider_list.append(full_llm_provider)
end_time = datetime.now(timezone.utc)
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching basic LLM providers in {duration:.2f} seconds")
return llm_provider_list
@admin_router.get("/provider-contextual-cost")
def get_provider_contextual_cost(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMCost]:
"""
Get the cost of Re-indexing all documents for contextual retrieval.
See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token
This includes:
- The cost of invoking the LLM on each chunk-document pair to get
- the doc_summary
- the chunk_context
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
"""
providers = fetch_existing_llm_providers(db_session)
costs = []
for provider in providers:
for model_name in provider.display_model_names or provider.model_names or []:
llm = get_llm(
provider=provider.provider,
model=model_name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
)
cost = get_llm_contextual_cost(llm)
costs.append(
LLMCost(provider=provider.name, model_name=model_name, cost=cost)
)
return costs