from typing import Any from onyx.chat.models import PersonaOverrideConfig from onyx.configs.app_configs import DISABLE_GENERATIVE_AI from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.db.engine import get_session_context_manager from onyx.db.llm import fetch_default_provider from onyx.db.llm import fetch_provider from onyx.db.models import Persona from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger logger = setup_logger() def _build_extra_model_kwargs(provider: str) -> dict[str, Any]: """Ollama requires us to specify the max context window. For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value. TODO: allow model-specific values to be configured via the UI. """ return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {} def get_main_llm_from_tuple( llms: tuple[LLM, LLM], ) -> LLM: return llms[0] def get_llms_for_persona( persona: Persona | PersonaOverrideConfig | None, llm_override: LLMOverride | None = None, additional_headers: dict[str, str] | None = None, long_term_logger: LongTermLogger | None = None, ) -> tuple[LLM, LLM]: if persona is None: logger.warning("No persona provided, using default LLMs") return get_default_llms() model_provider_override = llm_override.model_provider if llm_override else None model_version_override = llm_override.model_version if llm_override else None temperature_override = llm_override.temperature if llm_override else None provider_name = model_provider_override or persona.llm_model_provider_override if not provider_name: return get_default_llms( temperature=temperature_override or GEN_AI_TEMPERATURE, additional_headers=additional_headers, long_term_logger=long_term_logger, ) with get_session_context_manager() as db_session: llm_provider = fetch_provider(db_session, provider_name) if not llm_provider: raise ValueError("No LLM provider found") model = model_version_override or persona.llm_model_version_override fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name if not model: raise ValueError("No model name found") if not fast_model: raise ValueError("No fast model name found") def _create_llm(model: str) -> LLM: return get_llm( provider=llm_provider.provider, model=model, deployment_name=llm_provider.deployment_name, api_key=llm_provider.api_key, api_base=llm_provider.api_base, api_version=llm_provider.api_version, custom_config=llm_provider.custom_config, temperature=temperature_override, additional_headers=additional_headers, long_term_logger=long_term_logger, ) return _create_llm(model), _create_llm(fast_model) def get_default_llms( timeout: int | None = None, temperature: float | None = None, additional_headers: dict[str, str] | None = None, long_term_logger: LongTermLogger | None = None, ) -> tuple[LLM, LLM]: if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() with get_session_context_manager() as db_session: llm_provider = fetch_default_provider(db_session) if not llm_provider: raise ValueError("No default LLM provider found") model_name = llm_provider.default_model_name fast_model_name = ( llm_provider.fast_default_model_name or llm_provider.default_model_name ) if not model_name: raise ValueError("No default model name found") if not fast_model_name: raise ValueError("No fast default model name found") def _create_llm(model: str) -> LLM: return get_llm( provider=llm_provider.provider, model=model, deployment_name=llm_provider.deployment_name, api_key=llm_provider.api_key, api_base=llm_provider.api_base, api_version=llm_provider.api_version, custom_config=llm_provider.custom_config, timeout=timeout, temperature=temperature, additional_headers=additional_headers, long_term_logger=long_term_logger, ) return _create_llm(model_name), _create_llm(fast_model_name) def get_llm( provider: str, model: str, deployment_name: str | None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, custom_config: dict[str, str] | None = None, temperature: float | None = None, timeout: int | None = None, additional_headers: dict[str, str] | None = None, long_term_logger: LongTermLogger | None = None, ) -> LLM: if temperature is None: temperature = GEN_AI_TEMPERATURE return DefaultMultiLLM( model_provider=provider, model_name=model, deployment_name=deployment_name, api_key=api_key, api_base=api_base, api_version=api_version, timeout=timeout, temperature=temperature, custom_config=custom_config, extra_headers=build_llm_extra_headers(additional_headers), model_kwargs=_build_extra_model_kwargs(provider), long_term_logger=long_term_logger, )