mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-23 12:31:30 +02:00
Add LiteLLM Support - Anthropic, Bedrock, Huggingface, TogetherAI, Replicate, etc. (#510)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
@@ -1,14 +1,26 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
from langchain.chat_models import ChatLiteLLM
|
||||||
|
import litellm # type:ignore
|
||||||
|
|
||||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||||
from danswer.llm.llm import LangChainChatLLM
|
from danswer.llm.llm import LangChainChatLLM
|
||||||
from danswer.llm.utils import should_be_verbose
|
from danswer.llm.utils import should_be_verbose
|
||||||
|
|
||||||
|
|
||||||
|
# Not every provider reachable through litellm supports the full OpenAI
# parameter set (e.g. frequency/presence penalties). Setting drop_params
# tells litellm to silently discard unsupported parameters instead of
# erroring when a user configures a different model.
litellm.drop_params = True
||||||
class OpenAIGPT(LangChainChatLLM):
    """Chat LLM backed by litellm's `ChatLiteLLM` langchain wrapper."""

    # Extra model parameters passed on every request. Providers that do not
    # support these are handled by `litellm.drop_params = True` at module
    # scope, which discards unsupported parameters rather than failing.
    DEFAULT_MODEL_PARAMS = {
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
@@ -19,22 +31,20 @@ class OpenAIGPT(LangChainChatLLM):
|
|||||||
*args: list[Any],
|
*args: list[Any],
|
||||||
**kwargs: dict[str, Any]
|
**kwargs: dict[str, Any]
|
||||||
):
|
):
|
||||||
self._llm = ChatOpenAI(
|
litellm.api_key = api_key
|
||||||
|
|
||||||
|
self._llm = ChatLiteLLM( # type: ignore
|
||||||
model=model_version,
|
model=model_version,
|
||||||
openai_api_key=api_key,
|
|
||||||
# Prefer using None which is the default value, endpoint could be empty string
|
# Prefer using None which is the default value, endpoint could be empty string
|
||||||
openai_api_base=cast(str, kwargs.get("endpoint")) or None,
|
api_base=cast(str, kwargs.get("endpoint")) or None,
|
||||||
max_tokens=max_output_tokens,
|
max_tokens=max_output_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
request_timeout=timeout,
|
request_timeout=timeout,
|
||||||
model_kwargs={
|
model_kwargs=OpenAIGPT.DEFAULT_MODEL_PARAMS,
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0,
|
|
||||||
},
|
|
||||||
verbose=should_be_verbose(),
|
verbose=should_be_verbose(),
|
||||||
max_retries=0, # retries are handled outside of langchain
|
max_retries=0, # retries are handled outside of langchain
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def llm(self) -> ChatOpenAI:
|
def llm(self) -> ChatLiteLLM:
|
||||||
return self._llm
|
return self._llm
|
||||||
|
Reference in New Issue
Block a user