Add LiteLLM Support - Anthropic, Bedrock, Huggingface, TogetherAI, Replicate, etc. (#510)

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
Krish Dholakia
2023-10-31 12:01:15 -07:00
committed by GitHub
parent c6663d83d5
commit ee0d092dcc

View File

@@ -1,14 +1,26 @@
from typing import Any from typing import Any
from typing import cast from typing import cast
from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models import ChatLiteLLM
import litellm # type:ignore
from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose from danswer.llm.utils import should_be_verbose
# If a user configures a different model and it doesn't support all the same
# parameters (e.g. frequency_penalty / presence_penalty), have litellm silently
# drop the unsupported parameters instead of raising an error.
# NOTE: this is module-level global state on the litellm package, so it
# affects every litellm-backed model in the process.
litellm.drop_params=True
class OpenAIGPT(LangChainChatLLM):
    """Chat LLM wrapper backed by litellm's ChatLiteLLM.

    Despite the historical name, requests are routed through litellm,
    which supports many providers beyond OpenAI.
    """

    # Parameters sent with every request. Providers that do not support
    # them have them stripped, because litellm.drop_params is enabled at
    # module level.
    DEFAULT_MODEL_PARAMS = {
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
@@ -19,22 +31,20 @@ class OpenAIGPT(LangChainChatLLM):
*args: list[Any], *args: list[Any],
**kwargs: dict[str, Any] **kwargs: dict[str, Any]
): ):
self._llm = ChatOpenAI( litellm.api_key = api_key
self._llm = ChatLiteLLM( # type: ignore
model=model_version, model=model_version,
openai_api_key=api_key,
# Prefer using None which is the default value, endpoint could be empty string # Prefer using None which is the default value, endpoint could be empty string
openai_api_base=cast(str, kwargs.get("endpoint")) or None, api_base=cast(str, kwargs.get("endpoint")) or None,
max_tokens=max_output_tokens, max_tokens=max_output_tokens,
temperature=temperature, temperature=temperature,
request_timeout=timeout, request_timeout=timeout,
model_kwargs={ model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
"frequency_penalty": 0,
"presence_penalty": 0,
},
verbose=should_be_verbose(), verbose=should_be_verbose(),
max_retries=0, # retries are handled outside of langchain max_retries=0, # retries are handled outside of langchain
) )
@property @property
def llm(self) -> ChatOpenAI: def llm(self) -> ChatLiteLLM:
return self._llm return self._llm