2024-12-13 09:56:10 -08:00

110 lines
4.1 KiB
Python

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_embedding_providers
from onyx.db.llm import remove_embedding_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.models import User
from onyx.db.search_settings import get_all_search_settings
from onyx.db.search_settings import get_current_db_embedding_provider
from onyx.indexing.models import EmbeddingModelDetail
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.embedding.models import TestEmbeddingRequest
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
# Module-level logger obtained from the project's logging helper.
logger = setup_logger()
# Router whose routes are all served under the "/admin/embedding" prefix.
admin_router = APIRouter(prefix="/admin/embedding")
# Router for the "/embedding" prefix; no routes are attached to it in the code visible here.
basic_router = APIRouter(prefix="/embedding")
@admin_router.post("/test-embedding")
def test_embedding_configuration(
    test_llm_request: TestEmbeddingRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    """Check an embedding configuration by encoding one sample string.

    Builds an ``EmbeddingModel`` from the request fields, pointed at the
    configured model server host and port, and encodes a single short text as
    a query. Returning without raising means the encode call succeeded.

    Args:
        test_llm_request: Settings to try: API key, API URL, provider type,
            model name, API version and deployment name.
        _: Value injected by the ``current_admin_user`` dependency; not used
            in the body.

    Raises:
        ValueError: If constructing or calling the model raises
            ``ValueError``. The original error is attached as the cause.
        HTTPException: Status 400 with a generic message for any other
            exception. The original error is attached as the cause.
    """
    try:
        test_model = EmbeddingModel(
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
            api_key=test_llm_request.api_key,
            api_url=test_llm_request.api_url,
            provider_type=test_llm_request.provider_type,
            model_name=test_llm_request.model_name,
            api_version=test_llm_request.api_version,
            deployment_name=test_llm_request.deployment_name,
            normalize=False,
            query_prefix=None,
            passage_prefix=None,
        )
        test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
        logger.error(error_msg)
        # Chain explicitly so the original traceback is kept as the direct cause.
        raise ValueError(error_msg) from e
    except Exception as e:
        error_msg = "An error occurred while testing your embedding model. Please check your configuration."
        logger.error(f"{error_msg} Error message: {e}", exc_info=True)
        # The client only sees the generic message; the cause stays server-side.
        raise HTTPException(status_code=400, detail=error_msg) from e
@admin_router.get("", response_model=list[EmbeddingModelDetail])
def list_embedding_models(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[EmbeddingModelDetail]:
    """Return one ``EmbeddingModelDetail`` per search settings row.

    Args:
        _: Value injected by the ``current_admin_user`` dependency; not used
            in the body.
        db_session: Database session passed to ``get_all_search_settings``.

    Returns:
        The result of ``EmbeddingModelDetail.from_db_model`` for each item
        returned by ``get_all_search_settings``, in the same order.
    """
    details: list[EmbeddingModelDetail] = []
    for db_setting in get_all_search_settings(db_session):
        details.append(EmbeddingModelDetail.from_db_model(db_setting))
    return details
@admin_router.get("/embedding-provider")
def list_embedding_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
    """Return the stored embedding providers as ``CloudEmbeddingProvider`` objects.

    Args:
        _: Value injected by the ``current_admin_user`` dependency; not used
            in the body.
        db_session: Database session passed to
            ``fetch_existing_embedding_providers``.

    Returns:
        The result of ``CloudEmbeddingProvider.from_request`` for each provider
        returned by ``fetch_existing_embedding_providers``, in the same order.
    """
    providers: list[CloudEmbeddingProvider] = []
    for db_provider in fetch_existing_embedding_providers(db_session):
        providers.append(CloudEmbeddingProvider.from_request(db_provider))
    return providers
@admin_router.delete("/embedding-provider/{provider_type}")
def delete_embedding_provider(
    provider_type: EmbeddingProvider,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Remove the embedding provider of the given type unless it is the current one.

    Args:
        provider_type: Provider type taken from the URL path.
        _: Value injected by the ``current_admin_user`` dependency; not used
            in the body.
        db_session: Database session used for the lookup and the removal.

    Raises:
        HTTPException: Status 400 when ``get_current_db_embedding_provider``
            returns a provider whose ``provider_type`` equals the requested one.
    """
    current_provider = get_current_db_embedding_provider(db_session=db_session)

    # Refuse the request when it targets the provider that is currently in use.
    targets_current_provider = (
        current_provider is not None
        and provider_type == current_provider.provider_type
    )
    if targets_current_provider:
        raise HTTPException(
            status_code=400, detail="You can't delete a currently active model"
        )

    remove_embedding_provider(db_session, provider_type=provider_type)
@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
    provider: CloudEmbeddingProviderCreationRequest,
    # Annotated ``User | None`` to match the other endpoints in this module
    # that take the same ``current_admin_user`` dependency.
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
    """Store a cloud embedding provider via ``upsert_cloud_embedding_provider``.

    Args:
        provider: Provider definition taken from the request body.
        _: Value injected by the ``current_admin_user`` dependency; not used
            in the body.
        db_session: Database session passed to the upsert helper.

    Returns:
        Whatever ``upsert_cloud_embedding_provider`` returns for this provider.
    """
    return upsert_cloud_embedding_provider(db_session, provider)