from collections.abc import Callable
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import User
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()

admin_router = APIRouter(prefix="/admin/llm")
basic_router = APIRouter(prefix="/llm")


@admin_router.get("/built-in/options")
def fetch_llm_options(
_: User | None = Depends(current_admin_user),
) -> list[WellKnownLLMProviderDescriptor]:
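    """List the well-known LLM providers that can be configured out of the box."""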
return fetch_available_well_known_llms()
@admin_router.post("/test")
def test_llm_configuration(
test_llm_request: TestLLMRequest,
_: User | None = Depends(current_admin_user),
) -> None:
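    """Test that the given LLM configuration is reachable with the supplied
    credentials. If a distinct fast model is configured, it is tested in
    parallel with the default model. Raises a 400 with a user-friendly error
    message if either test fails.

    Illustrative request body (field names mirror the TestLLMRequest usage
    below; the values are placeholders):

        {
            "provider": "openai",
            "default_model_name": "gpt-4o",
            "fast_default_model_name": "gpt-4o-mini",
            "api_key": "sk-..."
        }
    """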
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
api_key=test_llm_request.api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
deployment_name=test_llm_request.deployment_name,
)
functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
if (
test_llm_request.fast_default_model_name
and test_llm_request.fast_default_model_name
!= test_llm_request.default_model_name
):
fast_llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name,
api_key=test_llm_request.api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
deployment_name=test_llm_request.deployment_name,
)
functions_with_args.append((test_llm, (fast_llm,)))
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error = parallel_results[0] or (
parallel_results[1] if len(parallel_results) > 1 else None
)
if error:
client_error_msg = litellm_exception_to_error_msg(
error, llm, fallback_to_error_msg=True
)
raise HTTPException(status_code=400, detail=client_error_msg)
@admin_router.post("/test/default")
def test_default_provider(
_: User | None = Depends(current_admin_user),
) -> None:
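    """Test the currently configured default and fast default LLMs in parallel.
    Raises a 400 if no default provider is configured or if either test fails.
    """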
    try:
        llm, fast_llm = get_default_llms()
    except ValueError as e:
        logger.exception("Failed to fetch default LLM Provider")
        raise HTTPException(status_code=400, detail="No LLM Provider setup") from e
functions_with_args: list[tuple[Callable, tuple]] = [
(test_llm, (llm,)),
(test_llm, (fast_llm,)),
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error = parallel_results[0] or (
parallel_results[1] if len(parallel_results) > 1 else None
)
if error:
raise HTTPException(status_code=400, detail=error)
@admin_router.get("/provider")
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
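    """Return the full configuration of every LLM provider. Admin-only."""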
return [
FullLLMProvider.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
@admin_router.put("/provider")
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
    is_creation: bool = Query(
        False,
        description="True if creating a new provider, False if updating an existing one",
    ),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
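    """Create or update an LLM provider. When is_creation is true, the request
    is rejected if a provider with the same name already exists.
    """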
    # Validate the request (e.g. if we intend to create a provider but the name
    # already exists, we should throw an error). NOTE: this may involve a
    # duplicate fetch from Postgres, but we assume SQLAlchemy is smart enough
    # to cache the result.
existing_provider = fetch_provider(db_session, llm_provider.name)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.display_model_names is None:
llm_provider.display_model_names = []
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
try:
return upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
)
    except ValueError as e:
        logger.exception("Failed to upsert LLM Provider")
        raise HTTPException(status_code=400, detail=str(e)) from e
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
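    """Delete the LLM provider with the given ID."""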
remove_llm_provider(db_session, provider_id)
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
provider_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
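    """Mark the LLM provider with the given ID as the default provider."""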
update_default_provider(provider_id=provider_id, db_session=db_session)
"""Endpoints for all"""
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
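    """Return a slimmed-down descriptor for each LLM provider visible to the
    current user. Unlike the admin endpoint, this is available to any user
    with chat access.
    """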
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
]