Add LiteLLM Support - Anthropic, Bedrock, Huggingface, TogetherAI, Replicate, etc. (#510)

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Author: Krish Dholakia, 2023-10-31 12:01:15 -07:00 (committed via GitHub)
parent c6663d83d5
commit ee0d092dcc


@@ -1,14 +1,26 @@
 from typing import Any
 from typing import cast
 
-from langchain.chat_models.openai import ChatOpenAI
+from langchain.chat_models import ChatLiteLLM
+import litellm  # type:ignore
 
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.llm import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
 
+# If a user configures a different model and it doesn't support all the same
+# parameters like frequency and presence, just ignore them
+litellm.drop_params = True
+
 
 class OpenAIGPT(LangChainChatLLM):
+    DEFAULT_MODEL_PARAMS = {
+        "frequency_penalty": 0,
+        "presence_penalty": 0,
+    }
+
     def __init__(
         self,
         api_key: str,
@@ -19,22 +31,20 @@ class OpenAIGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
-        self._llm = ChatOpenAI(
+        litellm.api_key = api_key
+        self._llm = ChatLiteLLM(  # type: ignore
             model=model_version,
-            openai_api_key=api_key,
             # Prefer using None which is the default value, endpoint could be empty string
-            openai_api_base=cast(str, kwargs.get("endpoint")) or None,
+            api_base=cast(str, kwargs.get("endpoint")) or None,
             max_tokens=max_output_tokens,
             temperature=temperature,
             request_timeout=timeout,
-            model_kwargs={
-                "frequency_penalty": 0,
-                "presence_penalty": 0,
-            },
+            model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
             verbose=should_be_verbose(),
             max_retries=0,  # retries are handled outside of langchain
         )
 
     @property
-    def llm(self) -> ChatOpenAI:
+    def llm(self) -> ChatLiteLLM:
         return self._llm
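
For context (not part of this commit): a minimal sketch of what the switch to ChatLiteLLM enables, assuming a 2023-era langchain install. The constructor arguments mirror the ones used in the diff above; the model name "claude-2", the placeholder API key, and the prompt are illustrative assumptions, not values from the commit.

    import litellm
    from langchain.chat_models import ChatLiteLLM
    from langchain.schema import HumanMessage

    # Drop params the target provider doesn't support (e.g. frequency_penalty)
    # instead of raising an error
    litellm.drop_params = True
    litellm.api_key = "sk-..."  # placeholder provider key

    llm = ChatLiteLLM(
        model="claude-2",  # illustrative: an Anthropic model routed through LiteLLM
        api_base=None,  # None falls back to the provider's default endpoint
        max_tokens=512,
        temperature=0.0,
        request_timeout=30,
        model_kwargs={"frequency_penalty": 0, "presence_penalty": 0},
        max_retries=0,
    )
    print(llm([HumanMessage(content="Hello!")]).content)

Because drop_params is set, the OpenAI-style penalty parameters in DEFAULT_MODEL_PARAMS are silently discarded for providers that don't accept them, which is what lets the same wrapper serve Anthropic, Bedrock, Huggingface, TogetherAI, Replicate, and the other LiteLLM backends named in the commit title.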