diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 5ebc3d960..addcafaf7 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -90,6 +90,8 @@ GEN_AI_API_KEY = (
 GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
 # API Version, such as (for Azure): 2023-09-15-preview
 GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
+# LiteLLM custom_llm_provider
+GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
 
diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index af81f6be6..52811b7f3 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -9,6 +9,7 @@ from langchain.schema.language_model import LanguageModelInput
 from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
+from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -42,8 +43,10 @@ class LangChainChatLLM(LLM, abc.ABC):
         logger.debug(f"Prompt:\n{prompt}")
 
     def log_model_configs(self) -> None:
+        llm_dict = {k: v for k, v in self.llm.__dict__.items() if v}
+        llm_dict.pop("client")
         logger.info(
-            f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
+            f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {llm_dict}"
         )
 
     def invoke(self, prompt: LanguageModelInput) -> str:
@@ -105,21 +108,25 @@ class DefaultMultiLLM(LangChainChatLLM):
         self,
         api_key: str | None,
         timeout: int,
-        model_provider: str | None = GEN_AI_MODEL_PROVIDER,
-        model_version: str | None = GEN_AI_MODEL_VERSION,
+        model_provider: str = GEN_AI_MODEL_PROVIDER,
+        model_version: str = GEN_AI_MODEL_VERSION,
         api_base: str | None = GEN_AI_API_ENDPOINT,
         api_version: str | None = GEN_AI_API_VERSION,
+        custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
         max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
     ):
         # Litellm Langchain integration currently doesn't take in the api key param
         # Can place this in the call below once integration is in
-        litellm.api_key = api_key
+        litellm.api_key = api_key or "dummy-key"
         litellm.api_version = api_version
 
         self._llm = ChatLiteLLM(  # type: ignore
-            model=_get_model_str(model_provider, model_version),
+            model=model_version
+            if custom_llm_provider
+            else _get_model_str(model_provider, model_version),
             api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
             max_tokens=max_output_tokens,
             temperature=temperature,
             request_timeout=timeout,
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 19ddee6f6..8c59ee028 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -22,6 +22,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
+      - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
       - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
@@ -87,6 +88,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
      - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
+      - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - POSTGRES_HOST=relational_db
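Net effect of the `chat_llm.py` change: the model string handed to LiteLLM now depends on whether a custom provider type is configured. A minimal sketch of that selection logic follows; `resolve_model` is a hypothetical helper name, and the `"provider/model"` join inside `_get_model_str` is an assumption for illustration, not code copied from the repo:

```python
def _get_model_str(model_provider: str, model_version: str) -> str:
    # Assumption: the repo helper joins provider and model so LiteLLM
    # can route on the "provider/" prefix.
    return f"{model_provider}/{model_version}"


def resolve_model(
    model_provider: str,
    model_version: str,
    custom_llm_provider: str | None,
) -> str:
    # Hypothetical helper mirroring the inline conditional in the diff.
    if custom_llm_provider:
        # The provider is passed separately via custom_llm_provider,
        # so the bare model name goes through unprefixed.
        return model_version
    return _get_model_str(model_provider, model_version)


# resolve_model("openai", "gpt-4", None)               -> "openai/gpt-4"
# resolve_model("openai", "my-served-model", "openai") -> "my-served-model"
```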
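In deployment terms, this lets an OpenAI-compatible self-hosted endpoint be targeted by setting `GEN_AI_LLM_PROVIDER_TYPE` (for example to `openai`), `GEN_AI_MODEL_VERSION` to the served model's name, and `GEN_AI_API_ENDPOINT` to the server URL. The `api_key or "dummy-key"` fallback appears aimed at such endpoints, which often require a non-empty key header even when no real credential exists.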