diff --git a/backend/danswer/llm/openai.py b/backend/danswer/llm/openai.py
index 90fce9035..48673aa22 100644
--- a/backend/danswer/llm/openai.py
+++ b/backend/danswer/llm/openai.py
@@ -1,14 +1,26 @@
 from typing import Any
 from typing import cast
 
-from langchain.chat_models.openai import ChatOpenAI
+from langchain.chat_models import ChatLiteLLM
+import litellm  # type: ignore
 
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.llm import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
 
 
+# If a user configures a different model that doesn't support all the same
+# parameters (e.g. frequency and presence penalties), just ignore them
+litellm.drop_params = True
+
+
 class OpenAIGPT(LangChainChatLLM):
+
+    DEFAULT_MODEL_PARAMS = {
+        "frequency_penalty": 0,
+        "presence_penalty": 0,
+    }
+
     def __init__(
         self,
         api_key: str,
@@ -19,22 +31,20 @@ class OpenAIGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
-        self._llm = ChatOpenAI(
+        litellm.api_key = api_key
+
+        self._llm = ChatLiteLLM(  # type: ignore
             model=model_version,
-            openai_api_key=api_key,
             # Prefer None, which is the default value; the endpoint could be an empty string
-            openai_api_base=cast(str, kwargs.get("endpoint")) or None,
+            api_base=cast(str, kwargs.get("endpoint")) or None,
             max_tokens=max_output_tokens,
             temperature=temperature,
             request_timeout=timeout,
-            model_kwargs={
-                "frequency_penalty": 0,
-                "presence_penalty": 0,
-            },
+            model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
             verbose=should_be_verbose(),
             max_retries=0,  # retries are handled outside of langchain
         )
 
     @property
-    def llm(self) -> ChatOpenAI:
+    def llm(self) -> ChatLiteLLM:
         return self._llm
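Reviewer note (not part of the patch): the behavioral crux of this change is `litellm.drop_params = True`. Below is a minimal sketch of what that flag enables, calling litellm's public `completion` API directly; the `ollama/llama2` model name and the prompt are illustrative assumptions, not anything this diff configures.

```python
import litellm

# With drop_params enabled, request parameters the target provider does not
# support (e.g. frequency_penalty on some non-OpenAI backends) are silently
# dropped instead of raising an error.
litellm.drop_params = True

response = litellm.completion(
    model="ollama/llama2",  # illustrative; any litellm-supported backend
    messages=[{"role": "user", "content": "Hello"}],
    frequency_penalty=0,  # dropped if the backend does not accept it
    presence_penalty=0,
)
print(response.choices[0].message.content)
```

This is what lets `DEFAULT_MODEL_PARAMS` be passed unconditionally in the constructor: providers that understand the penalty parameters receive them, and other backends simply ignore them.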