Add migration for env variable based LLM providers

This commit is contained in:
Weves 2024-05-20 12:42:03 -07:00 committed by Chris Weaver
parent 4413c0df36
commit bbae63b769
6 changed files with 87 additions and 5 deletions

View File

@ -0,0 +1,78 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.llm.llm_provider_options import AZURE_PROVIDER_NAME
from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
def load_llm_providers(db_session: Session) -> None:
existing_providers = fetch_existing_llm_providers(db_session)
if existing_providers:
return
if not GEN_AI_API_KEY or DISABLE_GENERATIVE_AI:
return
well_known_provider_name_to_provider = {
provider.name: provider
for provider in fetch_available_well_known_llms()
if provider.name != BEDROCK_PROVIDER_NAME
}
if GEN_AI_MODEL_PROVIDER not in well_known_provider_name_to_provider:
logger.error(f"Cannot auto-transition LLM provider: {GEN_AI_MODEL_PROVIDER}")
return None
# Azure provider requires custom model names,
# OpenAI / anthropic can just use the defaults
model_names = (
[
name
for name in [
GEN_AI_MODEL_VERSION,
FAST_GEN_AI_MODEL_VERSION,
]
if name
]
if GEN_AI_MODEL_PROVIDER == AZURE_PROVIDER_NAME
else None
)
well_known_provider = well_known_provider_name_to_provider[GEN_AI_MODEL_PROVIDER]
llm_provider_request = LLMProviderUpsertRequest(
name=well_known_provider.display_name,
provider=GEN_AI_MODEL_PROVIDER,
api_key=GEN_AI_API_KEY,
api_base=GEN_AI_API_ENDPOINT,
api_version=GEN_AI_API_VERSION,
custom_config={},
default_model_name=(
GEN_AI_MODEL_VERSION
or well_known_provider.default_model
or well_known_provider.llm_names[0]
),
fast_default_model_name=(
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)
logger.info(
f"Migrated LLM provider from env variables for provider '{GEN_AI_MODEL_PROVIDER}'"
)

View File

@ -11,7 +11,7 @@ class CustomConfigKey(BaseModel):
class WellKnownLLMProviderDescriptor(BaseModel):
name: str
display_name: str | None = None
display_name: str
api_key_required: bool
api_base_required: bool
api_version_required: bool

View File

@ -46,6 +46,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
@ -199,6 +200,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.info("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()

View File

@ -15,8 +15,8 @@ from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_llm
from danswer.llm.options import fetch_available_well_known_llms
from danswer.llm.options import WellKnownLLMProviderDescriptor
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderDescriptor

View File

@ -2,7 +2,7 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel
from danswer.llm.options import fetch_models_for_provider
from danswer.llm.llm_provider_options import fetch_models_for_provider
if TYPE_CHECKING:
from danswer.db.models import LLMProvider as LLMProviderModel

View File

@ -7,7 +7,7 @@ export interface CustomConfigKey {
export interface WellKnownLLMProviderDescriptor {
name: string;
display_name: string | null;
display_name: string;
api_key_required: boolean;
api_base_required: boolean;