Support passing extra headers to LiteLLM via environment variables

This commit is contained in:
Mehmet Bektas 2024-05-03 14:16:03 -07:00 committed by Chris Weaver
parent 2ff207218e
commit 6cbfe1bcdb
4 changed files with 24 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import json
import os
#####
@@ -97,3 +98,15 @@ GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
DISABLE_LITELLM_STREAMING = (
os.environ.get("DISABLE_LITELLM_STREAMING") or "false"
).lower() == "true"
# extra headers to pass to LiteLLM
LITELLM_EXTRA_HEADERS = None
if os.environ.get("LITELLM_EXTRA_HEADERS"):
try:
LITELLM_EXTRA_HEADERS = json.loads(os.environ.get("LITELLM_EXTRA_HEADERS"))
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(
"Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
)

View File

@@ -106,6 +106,7 @@ class DefaultMultiLLM(LangChainChatLLM):
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
):
self._model_provider = model_provider
self._model_version = model_name
@@ -123,6 +124,11 @@ class DefaultMultiLLM(LangChainChatLLM):
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs = DefaultMultiLLM.DEFAULT_MODEL_PARAMS if model_provider == "openai" else {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
self._llm = ChatLiteLLM( # type: ignore
model=(
model_name if custom_llm_provider else f"{model_provider}/{model_name}"
@@ -134,9 +140,7 @@ class DefaultMultiLLM(LangChainChatLLM):
request_timeout=timeout,
# LiteLLM and some model providers don't handle these params well
# only turning it on for OpenAI
model_kwargs=DefaultMultiLLM.DEFAULT_MODEL_PARAMS
if model_provider == "openai"
else {},
model_kwargs=model_kwargs,
verbose=should_be_verbose(),
max_retries=0, # retries are handled outside of langchain
)

View File

@@ -1,6 +1,6 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import GEN_AI_TEMPERATURE, LITELLM_EXTRA_HEADERS
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import fetch_provider
@@ -87,4 +87,5 @@ def get_llm(
timeout=timeout,
temperature=temperature,
custom_config=custom_config,
extra_headers=LITELLM_EXTRA_HEADERS,
)

View File

@@ -46,6 +46,7 @@ services:
- DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
# if set, allows for the use of the token budget system
- TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
# Enables the use of bedrock models
@@ -122,6 +123,7 @@ services:
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- GENERATIVE_MODEL_ACCESS_CHECK_FREQ=${GENERATIVE_MODEL_ACCESS_CHECK_FREQ:-}
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)