Add support for auto-refreshing available models based on an API call (#3576)
@@ -286,5 +286,6 @@ celery_app.autodiscover_tasks(
         "onyx.background.celery.tasks.pruning",
         "onyx.background.celery.tasks.shared",
         "onyx.background.celery.tasks.vespa",
+        "onyx.background.celery.tasks.llm_model_update",
     ]
 )
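Registering the new package with autodiscover_tasks is what makes the Celery worker import its tasks.py (the new file further down) and pick up the shared task at startup. A minimal, hypothetical sketch of that mechanism outside the Onyx codebase:

from celery import Celery

app = Celery("example")

# Celery looks for a `tasks` module inside each listed package, so adding the
# package path here is enough for the worker to register the new shared task.
app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.llm_model_update",
    ]
)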
@@ -1,6 +1,7 @@
 from datetime import timedelta
 from typing import Any
 
+from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryTask
 
@@ -87,6 +88,20 @@ tasks_to_schedule = [
     },
 ]
 
+# Only add the LLM model update task if the API URL is configured
+if LLM_MODEL_UPDATE_API_URL:
+    tasks_to_schedule.append(
+        {
+            "name": "check-for-llm-model-update",
+            "task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
+            "schedule": timedelta(hours=1),  # Check every hour
+            "options": {
+                "priority": OnyxCeleryPriority.LOW,
+                "expires": BEAT_EXPIRES_DEFAULT,
+            },
+        }
+    )
+
 
 def get_tasks_to_schedule() -> list[dict[str, Any]]:
     return tasks_to_schedule
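The append above only extends a plain list of task dicts; Celery beat still needs those dicts as a name-keyed schedule. The wiring that consumes get_tasks_to_schedule() is not part of this diff, so the following is only a hypothetical sketch of how such a list can be mapped onto Celery's beat_schedule (the priority value is a stand-in for OnyxCeleryPriority.LOW):

from datetime import timedelta

from celery import Celery

app = Celery("example")

# Illustrative entry mirroring the dict appended above.
tasks = [
    {
        "name": "check-for-llm-model-update",
        "task": "check_for_llm_model_update",
        "schedule": timedelta(hours=1),
        "options": {"priority": 0},
    }
]

# Turn the list of task dicts into Celery's name-keyed beat_schedule mapping.
app.conf.beat_schedule = {
    entry["name"]: {
        "task": entry["task"],
        "schedule": entry["schedule"],
        "options": entry.get("options", {}),
    }
    for entry in tasks
}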
backend/onyx/background/celery/tasks/llm_model_update/tasks.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from typing import Any

import requests
from celery import shared_task
from celery import Task

from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import LLMProvider


def _process_model_list_response(model_list_json: Any) -> list[str]:
    # Handle case where response is wrapped in a "data" field
    if isinstance(model_list_json, dict) and "data" in model_list_json:
        model_list_json = model_list_json["data"]

    if not isinstance(model_list_json, list):
        raise ValueError(
            f"Invalid response from API - expected list, got {type(model_list_json)}"
        )

    # Handle both string list and object list cases
    model_names: list[str] = []
    for item in model_list_json:
        if isinstance(item, str):
            model_names.append(item)
        elif isinstance(item, dict) and "model_name" in item:
            model_names.append(item["model_name"])
        else:
            raise ValueError(
                f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
            )

    return model_names


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
    if not LLM_MODEL_UPDATE_API_URL:
        raise ValueError("LLM model update API URL not configured")

    # First fetch the models from the API
    try:
        response = requests.get(LLM_MODEL_UPDATE_API_URL)
        response.raise_for_status()
        available_models = _process_model_list_response(response.json())
        task_logger.info(f"Found available models: {available_models}")

    except Exception:
        task_logger.exception("Failed to fetch models from API.")
        return None

    # Then update the database with the fetched models
    with get_session_with_tenant(tenant_id) as db_session:
        # Get the default LLM provider
        default_provider = (
            db_session.query(LLMProvider)
            .filter(LLMProvider.is_default_provider.is_(True))
            .first()
        )

        if not default_provider:
            task_logger.warning("No default LLM provider found")
            return None

        # log change if any
        old_models = set(default_provider.model_names or [])
        new_models = set(available_models)
        added_models = new_models - old_models
        removed_models = old_models - new_models

        if added_models:
            task_logger.info(f"Adding models: {sorted(added_models)}")
        if removed_models:
            task_logger.info(f"Removing models: {sorted(removed_models)}")

        # Update the provider's model list
        default_provider.model_names = available_models

        # if the default model is no longer available, set it to the first model in the list
        if default_provider.default_model_name not in available_models:
            task_logger.info(
                f"Default model {default_provider.default_model_name} not "
                f"available, setting to first model in list: {available_models[0]}"
            )
            default_provider.default_model_name = available_models[0]
        if default_provider.fast_default_model_name not in available_models:
            task_logger.info(
                f"Fast default model {default_provider.fast_default_model_name} "
                f"not available, setting to first model in list: {available_models[0]}"
            )
            default_provider.fast_default_model_name = available_models[0]
        db_session.commit()

        if added_models or removed_models:
            task_logger.info("Updated model list for default provider.")

    return True
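For reference, _process_model_list_response accepts either a bare list of model names, a list of objects carrying a model_name key, or either of those wrapped in a top-level "data" field. The payloads below are hypothetical; the actual shape depends on the endpoint behind LLM_MODEL_UPDATE_API_URL:

# Hypothetical payloads the parser above would accept (model names are examples only).
payload_plain = ["gpt-4o", "gpt-4o-mini"]
payload_objects = [{"model_name": "gpt-4o"}, {"model_name": "gpt-4o-mini"}]
payload_wrapped = {"data": [{"model_name": "gpt-4o"}]}

# Each resolves to a flat list of names:
#   _process_model_list_response(payload_plain)    -> ["gpt-4o", "gpt-4o-mini"]
#   _process_model_list_response(payload_objects)  -> ["gpt-4o", "gpt-4o-mini"]
#   _process_model_list_response(payload_wrapped)  -> ["gpt-4o"]
# Anything else (e.g. a dict without "data", or items missing "model_name")
# raises ValueError.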
@@ -537,6 +537,9 @@ try:
 except json.JSONDecodeError:
     pass
 
+# LLM Model Update API endpoint
+LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
+
 #####
 # Enterprise Edition Configs
 #####
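Setting this variable is what turns the feature on: the beat entry earlier in the diff is only appended when it is non-empty, and the task itself raises if it is unset. In practice it would be set in the deployment environment; shown here in Python only to keep one language across examples, with a placeholder URL:

import os

# Placeholder endpoint (not a real URL) that returns the available model list
# as JSON in one of the shapes handled by _process_model_list_response.
os.environ["LLM_MODEL_UPDATE_API_URL"] = "https://models.example.com/v1/models"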
@@ -249,6 +249,7 @@ class OnyxCeleryQueues:
     VESPA_METADATA_SYNC = "vespa_metadata_sync"
     DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
     CONNECTOR_DELETION = "connector_deletion"
+    LLM_MODEL_UPDATE = "llm_model_update"
 
     # Heavy queue
     CONNECTOR_PRUNING = "connector_pruning"
@@ -304,6 +305,7 @@ class OnyxCeleryTask:
     CHECK_FOR_PRUNING = "check_for_pruning"
     CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
     CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
+    CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
     MONITOR_VESPA_SYNC = "monitor_vespa_sync"
     KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
     CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
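With the task name and queue constants registered, the hourly check can also be kicked off on demand for testing via Celery's send_task. This is a hypothetical standalone sketch; in Onyx the configured primary Celery app and its broker would be used rather than the placeholder below:

from celery import Celery

# Placeholder app/broker; substitute the project's configured Celery app.
app = Celery("example", broker="redis://localhost:6379/0")

app.send_task(
    "check_for_llm_model_update",  # OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE
    kwargs={"tenant_id": None},    # the task takes tenant_id as keyword-only
    queue="llm_model_update",      # OnyxCeleryQueues.LLM_MODEL_UPDATE
)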