Custom OpenAI Model Server (#782)

Repository: https://github.com/danswer-ai/danswer.git
Commit: c2727a3f19 (parent: 37daf4f3e4)
@@ -90,6 +90,8 @@ GEN_AI_API_KEY = (
 GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
 # API Version, such as (for Azure): 2023-09-15-preview
 GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
+# LiteLLM custom_llm_provider
+GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
 
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
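The os.environ.get(...) or None idiom above treats an unset variable and an empty string identically, which matters because the docker-compose entries further down default missing variables to "". A minimal sketch of that behavior (the provider value is illustrative, not from the commit):

import os

# Unset and empty-string variables both normalize to None.
os.environ["GEN_AI_API_VERSION"] = ""                       # set but empty
print(os.environ.get("GEN_AI_API_VERSION") or None)         # None
print(os.environ.get("GEN_AI_API_ENDPOINT") or None)        # None (unset)

# A non-empty value passes through unchanged.
os.environ["GEN_AI_LLM_PROVIDER_TYPE"] = "custom_openai"    # illustrative value
print(os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None)   # custom_openai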
@@ -9,6 +9,7 @@ from langchain.schema.language_model import LanguageModelInput
 from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
+from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -42,8 +43,10 @@ class LangChainChatLLM(LLM, abc.ABC):
         logger.debug(f"Prompt:\n{prompt}")
 
     def log_model_configs(self) -> None:
+        llm_dict = {k: v for k, v in self.llm.__dict__.items() if v}
+        llm_dict.pop("client")
         logger.info(
-            f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
+            f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {llm_dict}"
         )
 
     def invoke(self, prompt: LanguageModelInput) -> str:
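The change above stops dumping the raw __dict__ (which includes the live client object) into the logs. A small self-contained sketch of the effect; FakeLLM is a hypothetical stand-in, not a class from this codebase:

class FakeLLM:
    def __init__(self) -> None:
        self.client = object()   # live connection object -- noisy, unserializable in logs
        self.model = "gpt-3.5-turbo"
        self.temperature = 0.0   # falsy, so the truthiness filter drops it
        self.api_base = None     # falsy, dropped as well

llm = FakeLLM()
llm_dict = {k: v for k, v in llm.__dict__.items() if v}  # keep only truthy settings
llm_dict.pop("client")  # the client object is truthy, so remove it explicitly
print(llm_dict)  # {'model': 'gpt-3.5-turbo'}

One side effect worth noting: the truthiness filter also hides legitimate zero-valued settings such as temperature=0.0.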
@@ -105,21 +108,25 @@ class DefaultMultiLLM(LangChainChatLLM):
         self,
         api_key: str | None,
         timeout: int,
-        model_provider: str | None = GEN_AI_MODEL_PROVIDER,
-        model_version: str | None = GEN_AI_MODEL_VERSION,
+        model_provider: str = GEN_AI_MODEL_PROVIDER,
+        model_version: str = GEN_AI_MODEL_VERSION,
         api_base: str | None = GEN_AI_API_ENDPOINT,
         api_version: str | None = GEN_AI_API_VERSION,
+        custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
         max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
     ):
         # Litellm Langchain integration currently doesn't take in the api key param
         # Can place this in the call below once integration is in
-        litellm.api_key = api_key
+        litellm.api_key = api_key or "dummy-key"
         litellm.api_version = api_version
 
         self._llm = ChatLiteLLM(  # type: ignore
-            model=_get_model_str(model_provider, model_version),
+            model=model_version
+            if custom_llm_provider
+            else _get_model_str(model_provider, model_version),
             api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
             max_tokens=max_output_tokens,
             temperature=temperature,
             request_timeout=timeout,
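A hedged usage sketch of the new constructor path for a self-hosted OpenAI-compatible server; the import path, endpoint URL, model name, and provider value below are illustrative assumptions, not part of the commit:

# Illustrative only -- every literal below is an assumption.
from danswer.llm.llm import DefaultMultiLLM  # actual module path may differ

llm = DefaultMultiLLM(
    api_key=None,                         # falls back to "dummy-key" internally
    timeout=60,
    model_version="my-local-model",       # used verbatim as `model` when a
    custom_llm_provider="custom_openai",  # custom provider type is supplied
    api_base="http://localhost:8000",     # the custom model server's endpoint
)

When custom_llm_provider is left unset, the model string is still built from model_provider and model_version via _get_model_str, so existing deployments behave exactly as before.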
@@ -22,6 +22,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
+      - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
       - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
@@ -87,6 +88,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
+      - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - POSTGRES_HOST=relational_db
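Together with the config changes above, both services can now be pointed at a custom model server purely through deployment environment variables: the ${VAR:-} syntax defaults each variable to an empty string when unset, which the `or None` idiom in model_configs then normalizes back to None.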