diff --git a/backend/danswer/llm/llm_initialization.py b/backend/danswer/llm/llm_initialization.py
new file mode 100644
index 000000000..5c6f8bdbe
--- /dev/null
+++ b/backend/danswer/llm/llm_initialization.py
@@ -0,0 +1,99 @@
+from sqlalchemy.orm import Session
+
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_API_VERSION
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.db.llm import fetch_existing_llm_providers
+from danswer.db.llm import update_default_provider
+from danswer.db.llm import upsert_llm_provider
+from danswer.llm.llm_provider_options import AZURE_PROVIDER_NAME
+from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
+from danswer.llm.llm_provider_options import fetch_available_well_known_llms
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()
+
+
+def load_llm_providers(db_session: Session) -> None:
+    """Create the initial LLM provider from the GEN_AI_* env configuration.
+
+    Does nothing when a provider already exists, when no API key is set,
+    when generative AI is disabled, or when the configured provider is not
+    an auto-migratable well-known provider (Bedrock is filtered out).
+    Otherwise a provider built from the env values is upserted and passed
+    to update_default_provider.
+
+    Args:
+        db_session: Session handed to the danswer.db.llm helpers.
+    """
+    existing_providers = fetch_existing_llm_providers(db_session)
+    if existing_providers:
+        return
+
+    if not GEN_AI_API_KEY or DISABLE_GENERATIVE_AI:
+        return
+
+    well_known_provider_name_to_provider = {
+        provider.name: provider
+        for provider in fetch_available_well_known_llms()
+        if provider.name != BEDROCK_PROVIDER_NAME
+    }
+
+    if GEN_AI_MODEL_PROVIDER not in well_known_provider_name_to_provider:
+        logger.error(f"Cannot auto-transition LLM provider: {GEN_AI_MODEL_PROVIDER}")
+        return
+
+    well_known_provider = well_known_provider_name_to_provider[GEN_AI_MODEL_PROVIDER]
+
+    # Resolve the default model up front: indexing llm_names[0] directly
+    # raises IndexError during app startup when the provider lists no
+    # model names and no model version is configured.
+    default_model_name = GEN_AI_MODEL_VERSION or well_known_provider.default_model
+    if not default_model_name and well_known_provider.llm_names:
+        default_model_name = well_known_provider.llm_names[0]
+    if not default_model_name:
+        logger.error(
+            "Cannot auto-transition LLM provider: no default model could be "
+            f"determined for provider '{GEN_AI_MODEL_PROVIDER}'"
+        )
+        return
+
+    # Azure provider requires custom model names,
+    # OpenAI / anthropic can just use the defaults
+    model_names = (
+        [
+            name
+            for name in [
+                GEN_AI_MODEL_VERSION,
+                FAST_GEN_AI_MODEL_VERSION,
+            ]
+            if name
+        ]
+        if GEN_AI_MODEL_PROVIDER == AZURE_PROVIDER_NAME
+        else None
+    )
+
+    llm_provider_request = LLMProviderUpsertRequest(
+        name=well_known_provider.display_name,
+        provider=GEN_AI_MODEL_PROVIDER,
+        api_key=GEN_AI_API_KEY,
+        api_base=GEN_AI_API_ENDPOINT,
+        api_version=GEN_AI_API_VERSION,
+        custom_config={},
+        default_model_name=default_model_name,
+        fast_default_model_name=(
+            FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
+        ),
+        model_names=model_names,
+    )
+    llm_provider = upsert_llm_provider(db_session, llm_provider_request)
+    update_default_provider(db_session, llm_provider.id)
+    logger.info(
+        f"Migrated LLM provider from env variables for provider '{GEN_AI_MODEL_PROVIDER}'"
+    )
diff --git a/backend/danswer/llm/options.py b/backend/danswer/llm/llm_provider_options.py
similarity index 99%
rename from backend/danswer/llm/options.py
rename to backend/danswer/llm/llm_provider_options.py
index 31b0060b2..3b2c62c6c 100644
--- a/backend/danswer/llm/options.py
+++ b/backend/danswer/llm/llm_provider_options.py
@@ -11,7 +11,7 @@ class CustomConfigKey(BaseModel):
 
 class WellKnownLLMProviderDescriptor(BaseModel):
     name: str
-    display_name: str | None = None
+    display_name: str
     api_key_required: bool
     api_base_required: bool
     api_version_required: bool
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 99e7e1c12..d717f06a4 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -46,6 +46,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import expire_index_attempts
 from danswer.db.swap_index import check_index_swap
 from danswer.document_index.factory import get_default_document_index
+from danswer.llm.llm_initialization import load_llm_providers
 from danswer.search.retrieval.search_runner import
download_nltk_data from danswer.search.search_nlp_models import warm_up_encoders from danswer.server.auth_check import check_router_auth @@ -199,6 +200,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: create_initial_default_connector(db_session) associate_default_cc_pair(db_session) + logger.info("Loading LLM providers from env variables") + load_llm_providers(db_session) + logger.info("Loading default Prompts and Personas") delete_old_default_personas(db_session) load_chat_yamls() diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py index 3b4673522..7bc4efe63 100644 --- a/backend/danswer/server/manage/llm/api.py +++ b/backend/danswer/server/manage/llm/api.py @@ -15,8 +15,8 @@ from danswer.db.llm import upsert_llm_provider from danswer.db.models import User from danswer.llm.factory import get_default_llm from danswer.llm.factory import get_llm -from danswer.llm.options import fetch_available_well_known_llms -from danswer.llm.options import WellKnownLLMProviderDescriptor +from danswer.llm.llm_provider_options import fetch_available_well_known_llms +from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor from danswer.llm.utils import test_llm from danswer.server.manage.llm.models import FullLLMProvider from danswer.server.manage.llm.models import LLMProviderDescriptor diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py index c989063df..0e791696a 100644 --- a/backend/danswer/server/manage/llm/models.py +++ b/backend/danswer/server/manage/llm/models.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from pydantic import BaseModel -from danswer.llm.options import fetch_models_for_provider +from danswer.llm.llm_provider_options import fetch_models_for_provider if TYPE_CHECKING: from danswer.db.models import LLMProvider as LLMProviderModel diff --git a/web/src/app/admin/models/llm/interfaces.ts b/web/src/app/admin/models/llm/interfaces.ts index 
90a06d19d..d78ce605e 100644 --- a/web/src/app/admin/models/llm/interfaces.ts +++ b/web/src/app/admin/models/llm/interfaces.ts @@ -7,7 +7,7 @@ export interface CustomConfigKey { export interface WellKnownLLMProviderDescriptor { name: string; - display_name: string | null; + display_name: string; api_key_required: boolean; api_base_required: boolean;