mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-09 12:47:13 +02:00
121 lines
4.4 KiB
Python
121 lines
4.4 KiB
Python
from typing import Any
|
|
|
|
import requests
|
|
from celery import shared_task
|
|
from celery import Task
|
|
|
|
from onyx.background.celery.apps.app_base import task_logger
|
|
from onyx.configs.app_configs import JOB_TIMEOUT
|
|
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
|
from onyx.configs.constants import OnyxCeleryTask
|
|
from onyx.db.engine import get_session_with_tenant
|
|
from onyx.db.models import LLMProvider
|
|
|
|
|
|
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
|
# Handle case where response is wrapped in a "data" field
|
|
if isinstance(model_list_json, dict):
|
|
if "data" in model_list_json:
|
|
model_list_json = model_list_json["data"]
|
|
elif "models" in model_list_json:
|
|
model_list_json = model_list_json["models"]
|
|
else:
|
|
raise ValueError(
|
|
"Invalid response from API - expected dict with 'data' or "
|
|
f"'models' field, got {type(model_list_json)}"
|
|
)
|
|
|
|
if not isinstance(model_list_json, list):
|
|
raise ValueError(
|
|
f"Invalid response from API - expected list, got {type(model_list_json)}"
|
|
)
|
|
|
|
# Handle both string list and object list cases
|
|
model_names: list[str] = []
|
|
for item in model_list_json:
|
|
if isinstance(item, str):
|
|
model_names.append(item)
|
|
elif isinstance(item, dict):
|
|
if "model_name" in item:
|
|
model_names.append(item["model_name"])
|
|
elif "id" in item:
|
|
model_names.append(item["id"])
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid item in model list - expected string or dict, got {type(item)}"
|
|
)
|
|
|
|
return model_names
|
|
|
|
|
|
# Upper bound (seconds) on the model-list HTTP request. Without it a stalled
# endpoint would block the worker until the task's soft time limit fires.
_MODEL_LIST_REQUEST_TIMEOUT_SECONDS = 30


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
    """Refresh the default LLM provider's model list from the update API.

    Fetches the model list from LLM_MODEL_UPDATE_API_URL, replaces the default
    provider's ``model_names`` with it, and points ``default_model_name`` /
    ``fast_default_model_name`` at the first fetched model whenever their
    current value is not in the fetched list.

    Args:
        tenant_id: Passed to ``get_session_with_tenant`` to open the session.

    Returns:
        True once the provider row has been committed. None if the fetch
        failed, the API returned no models, or no default provider exists.

    Raises:
        ValueError: If LLM_MODEL_UPDATE_API_URL is not configured.
    """
    if not LLM_MODEL_UPDATE_API_URL:
        raise ValueError("LLM model update API URL not configured")

    # First fetch the models from the API
    try:
        response = requests.get(
            LLM_MODEL_UPDATE_API_URL,
            timeout=_MODEL_LIST_REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        available_models = _process_model_list_response(response.json())
        task_logger.info(f"Found available models: {available_models}")
    except Exception:
        task_logger.exception("Failed to fetch models from API.")
        return None

    # An empty list would make the `available_models[0]` fallbacks below
    # raise IndexError, so treat it as a failed fetch and leave the database
    # untouched.
    if not available_models:
        task_logger.warning(
            "Model update API returned no models; skipping provider update."
        )
        return None

    # Then update the database with the fetched models
    with get_session_with_tenant(tenant_id) as db_session:
        # Get the default LLM provider
        default_provider = (
            db_session.query(LLMProvider)
            .filter(LLMProvider.is_default_provider.is_(True))
            .first()
        )

        if not default_provider:
            task_logger.warning("No default LLM provider found")
            return None

        # log change if any
        old_models = set(default_provider.model_names or [])
        new_models = set(available_models)
        added_models = new_models - old_models
        removed_models = old_models - new_models

        if added_models:
            task_logger.info(f"Adding models: {sorted(added_models)}")
        if removed_models:
            task_logger.info(f"Removing models: {sorted(removed_models)}")

        # Update the provider's model list
        default_provider.model_names = available_models

        # If a configured default is no longer available, fall back to the
        # first model in the fetched list.
        fallback_model = available_models[0]
        if default_provider.default_model_name not in new_models:
            task_logger.info(
                f"Default model {default_provider.default_model_name} not "
                f"available, setting to first model in list: {fallback_model}"
            )
            default_provider.default_model_name = fallback_model
        if default_provider.fast_default_model_name not in new_models:
            task_logger.info(
                f"Fast default model {default_provider.fast_default_model_name} "
                f"not available, setting to first model in list: {fallback_model}"
            )
            default_provider.fast_default_model_name = fallback_model

        db_session.commit()

        if added_models or removed_models:
            task_logger.info("Updated model list for default provider.")

    return True
|