mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-16 06:40:15 +02:00
110 lines
4.1 KiB
Python
110 lines
4.1 KiB
Python
from fastapi import APIRouter
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_admin_user
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.llm import fetch_existing_embedding_providers
|
|
from onyx.db.llm import remove_embedding_provider
|
|
from onyx.db.llm import upsert_cloud_embedding_provider
|
|
from onyx.db.models import User
|
|
from onyx.db.search_settings import get_all_search_settings
|
|
from onyx.db.search_settings import get_current_db_embedding_provider
|
|
from onyx.indexing.models import EmbeddingModelDetail
|
|
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
|
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
|
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
|
from onyx.server.manage.embedding.models import TestEmbeddingRequest
|
|
from onyx.utils.logger import setup_logger
|
|
from shared_configs.configs import MODEL_SERVER_HOST
|
|
from shared_configs.configs import MODEL_SERVER_PORT
|
|
from shared_configs.enums import EmbeddingProvider
|
|
from shared_configs.enums import EmbedTextType
|
|
|
|
|
|
# Module-level logger used by the endpoints in this file.
logger = setup_logger()

# Routes registered on this router in this file all depend on
# `current_admin_user`.
admin_router = APIRouter(prefix="/admin/embedding")
# Router for the "/embedding" prefix; no routes are attached to it in the
# portion of the file visible here.
basic_router = APIRouter(prefix="/embedding")
|
|
|
|
|
|
@admin_router.post("/test-embedding")
def test_embedding_configuration(
    test_llm_request: TestEmbeddingRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    """Check that the supplied embedding configuration can embed text.

    Builds an ``EmbeddingModel`` from the request fields, pointed at the
    configured model server, with normalization off and no query/passage
    prefixes, then encodes a single short query string with it.

    Args:
        test_llm_request: Provider settings to try (API key, API URL,
            provider type, model name, API version, deployment name).
        _: Result of the ``current_admin_user`` dependency; unused.

    Raises:
        ValueError: If building the model or encoding raises ``ValueError``;
            re-raised with a descriptive message, chained to the original.
        HTTPException: Status 400 with a generic message for any other
            failure, chained to the original exception.
    """
    try:
        test_model = EmbeddingModel(
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
            api_key=test_llm_request.api_key,
            api_url=test_llm_request.api_url,
            provider_type=test_llm_request.provider_type,
            model_name=test_llm_request.model_name,
            api_version=test_llm_request.api_version,
            deployment_name=test_llm_request.deployment_name,
            normalize=False,
            query_prefix=None,
            passage_prefix=None,
        )
        test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)

    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
        logger.error(error_msg)
        # Chain explicitly so the original traceback is kept as __cause__.
        raise ValueError(error_msg) from e

    except Exception as e:
        error_msg = "An error occurred while testing your embedding model. Please check your configuration."
        # The underlying error goes to the log only; the client receives the
        # generic message.
        logger.error(f"{error_msg} Error message: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=error_msg) from e
|
|
|
|
|
|
@admin_router.get("", response_model=list[EmbeddingModelDetail])
def list_embedding_models(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[EmbeddingModelDetail]:
    """Return one ``EmbeddingModelDetail`` per search-settings row.

    Fetches everything ``get_all_search_settings`` returns for the session
    and converts each entry with ``EmbeddingModelDetail.from_db_model``.
    """
    all_settings = get_all_search_settings(db_session)

    details = []
    for db_setting in all_settings:
        details.append(EmbeddingModelDetail.from_db_model(db_setting))
    return details
|
|
|
|
|
|
@admin_router.get("/embedding-provider")
def list_embedding_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
    """Return the existing embedding providers as ``CloudEmbeddingProvider``s.

    Each item yielded by ``fetch_existing_embedding_providers`` is converted
    with ``CloudEmbeddingProvider.from_request``.
    """
    stored_providers = fetch_existing_embedding_providers(db_session)

    providers = []
    for stored_provider in stored_providers:
        providers.append(CloudEmbeddingProvider.from_request(stored_provider))
    return providers
|
|
|
|
|
|
@admin_router.delete("/embedding-provider/{provider_type}")
def delete_embedding_provider(
    provider_type: EmbeddingProvider,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Remove the embedding provider of the given type.

    Raises:
        HTTPException: Status 400 when ``provider_type`` matches the provider
            returned by ``get_current_db_embedding_provider``.
    """
    current_provider = get_current_db_embedding_provider(db_session=db_session)

    # Refuse to remove the provider that is currently in use.
    if current_provider is not None:
        if provider_type == current_provider.provider_type:
            raise HTTPException(
                status_code=400, detail="You can't delete a currently active model"
            )

    remove_embedding_provider(db_session, provider_type=provider_type)
|
|
|
|
|
|
@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
    provider: CloudEmbeddingProviderCreationRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
    """Pass the request to ``upsert_cloud_embedding_provider`` and return its result."""
    upserted_provider = upsert_cloud_embedding_provider(db_session, provider)
    return upserted_provider
|