Ability to pass through headers to LLM call

Weves 2024-06-10 13:02:05 -07:00 committed by Chris Weaver
parent 180b592afe
commit b723627e0c
13 changed files with 99 additions and 35 deletions
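
In short: header names listed in the new LITELLM_PASS_THROUGH_HEADERS setting are copied from the incoming chat request and forwarded on the LiteLLM call, merged with the existing static LITELLM_EXTRA_HEADERS. To support this, the LLM is now built once (with the per-request headers) and passed explicitly into SearchPipeline, retrieval preprocessing, and the time/source filter flows, which previously constructed their own default LLM internally.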

View File

@@ -193,6 +193,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -228,7 +229,9 @@ def stream_chat_message_objects(
try:
llm = get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -410,7 +413,7 @@ def stream_chat_message_objects(
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm_config=llm.config,
llm=llm,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
chunks_above=new_msg_req.chunks_above,
@@ -455,7 +458,9 @@ def stream_chat_message_objects(
llm=(
llm
or get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
),
message_history=[
@@ -576,6 +581,7 @@ def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
@@ -583,6 +589,7 @@ def stream_chat_message(
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
)
for obj in objects:
yield get_json_line(obj.dict())
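
The new litellm_additional_headers parameter defaults to None, so existing callers are unchanged. A minimal sketch of a caller that supplies per-request headers (the header name and value are hypothetical, and the import path assumes this file is danswer/chat/process_message.py):

```python
from danswer.chat.process_message import stream_chat_message

def stream_with_headers(req, user):  # req: CreateChatMessageRequest
    # Hypothetical header captured from the incoming HTTP request.
    headers = {"X-User-Token": "abc123"}
    # stream_chat_message opens its own DB session and yields one JSON
    # line per chat packet, now tagging the LLM call with `headers`.
    yield from stream_chat_message(
        new_msg_req=req,
        user=user,
        litellm_additional_headers=headers,
    )
```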

View File

@@ -100,7 +100,7 @@ DISABLE_LITELLM_STREAMING = (
).lower() == "true"
# extra headers to pass to LiteLLM
LITELLM_EXTRA_HEADERS = None
LITELLM_EXTRA_HEADERS: dict[str, str] | None = None
_LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS")
if _LITELLM_EXTRA_HEADERS_RAW:
try:
@@ -113,3 +113,18 @@ if _LITELLM_EXTRA_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
)
# if specified, will pass through request headers to the call to the LLM
LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None
_LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
if _LITELLM_PASS_THROUGH_HEADERS_RAW:
try:
LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW)
except Exception:
# need to import here to avoid circular imports
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
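
Both settings are parsed with json.loads, so LITELLM_EXTRA_HEADERS must decode to a JSON object and LITELLM_PASS_THROUGH_HEADERS to a JSON list. A sketch of the expected shapes (the header names are made up):

```python
import json
import os

# Static headers attached to every LiteLLM call.
os.environ["LITELLM_EXTRA_HEADERS"] = json.dumps({"X-Team": "search"})
# Names of incoming request headers to forward on each call.
os.environ["LITELLM_PASS_THROUGH_HEADERS"] = json.dumps(["X-User-Token"])

# Mirrors the parsing above; a malformed value logs an error and the
# setting simply stays None rather than raising at import time.
extra = json.loads(os.environ["LITELLM_EXTRA_HEADERS"])  # dict[str, str]
passthrough = json.loads(os.environ["LITELLM_PASS_THROUGH_HEADERS"])  # list[str]
print(extra, passthrough)
```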

View File

@@ -13,7 +13,9 @@ from danswer.llm.override_models import LLMOverride
def get_llm_for_persona(
persona: Persona, llm_override: LLMOverride | None = None
persona: Persona,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
) -> LLM:
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
@@ -25,6 +27,7 @@
),
model_version=(model_version_override or persona.llm_model_version_override),
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
)
@@ -34,6 +37,7 @@ def get_default_llm(
use_fast_llm: bool = False,
model_provider_name: str | None = None,
model_version: str | None = None,
additional_headers: dict[str, str] | None = None,
) -> LLM:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
@@ -65,6 +69,7 @@
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
)
@@ -77,7 +82,14 @@ def get_llm(
custom_config: dict[str, str] | None = None,
temperature: float = GEN_AI_TEMPERATURE,
timeout: int = QA_TIMEOUT,
additional_headers: dict[str, str] | None = None,
) -> LLM:
extra_headers = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return DefaultMultiLLM(
model_provider=provider,
model_name=model,
@@ -87,5 +99,5 @@
timeout=timeout,
temperature=temperature,
custom_config=custom_config,
extra_headers=LITELLM_EXTRA_HEADERS,
extra_headers=extra_headers,
)
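
Note the merge order in get_llm: per-request headers are applied first and LITELLM_EXTRA_HEADERS second, so on a key collision the statically configured value wins. A self-contained sketch of that precedence:

```python
def merge_headers(
    additional_headers: dict[str, str] | None,
    litellm_extra_headers: dict[str, str] | None,
) -> dict[str, str]:
    # Same shape as the logic in get_llm above: request headers first,
    # then the static config, which overwrites any duplicate keys.
    extra_headers: dict[str, str] = {}
    if additional_headers:
        extra_headers.update(additional_headers)
    if litellm_extra_headers:
        extra_headers.update(litellm_extra_headers)
    return extra_headers

# "X-Env" collides, so the static value survives.
print(merge_headers({"X-User": "alice", "X-Env": "req"}, {"X-Env": "prod"}))
# -> {'X-User': 'alice', 'X-Env': 'prod'}
```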

View File

@@ -0,0 +1,22 @@
from fastapi.datastructures import Headers
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
if not LITELLM_PASS_THROUGH_HEADERS:
return {}
pass_through_headers: dict[str, str] = {}
for key in LITELLM_PASS_THROUGH_HEADERS:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]
return pass_through_headers
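
Starlette's Headers type already does case-insensitive lookups, so the lowercase fallback mainly matters when a plain dict of normalized keys is passed in. A quick sketch of the selection logic, pretending LITELLM_PASS_THROUGH_HEADERS is ["X-User-Id"]:

```python
from fastapi.datastructures import Headers

# Raw headers as the ASGI server delivers them (keys are lowercase bytes).
incoming = Headers(raw=[(b"x-user-id", b"1234"), (b"cookie", b"ignored")])

pass_through: dict[str, str] = {}
for key in ["X-User-Id"]:  # stand-in for LITELLM_PASS_THROUGH_HEADERS
    if key in incoming:  # case-insensitive for Headers objects
        pass_through[key] = incoming[key]
    elif key.lower() in incoming:  # fallback for plain dicts
        pass_through[key.lower()] = incoming[key.lower()]

print(pass_through)  # {'X-User-Id': '1234'} -- only allow-listed headers survive
```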

View File

@@ -172,7 +172,7 @@ def stream_answer_objects(
persona=chat_session.persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm_config=llm.config,
llm=llm,
pruning_config=document_pruning_config,
)

View File

@@ -10,6 +10,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.interfaces import LLM
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
@@ -54,6 +55,7 @@ class SearchPipeline:
self,
search_request: SearchRequest,
user: User | None,
llm: LLM,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -62,6 +64,7 @@
):
self.search_request = search_request
self.user = user
self.llm = llm
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -229,6 +232,7 @@
) = retrieval_preprocessing(
search_request=self.search_request,
user=self.user,
llm=self.llm,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
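
SearchPipeline now requires the LLM at construction time instead of building a default one internally, so a caller's per-request headers travel with it. Construction mirrors the call sites updated later in this commit (gpt_search and the regression test):

```python
from danswer.llm.factory import get_default_llm
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline

def top_chunks(query: str, db_session):
    pipeline = SearchPipeline(
        search_request=SearchRequest(query=query),
        user=None,
        llm=get_default_llm(),  # any LLM instance satisfies the new parameter
        db_session=db_session,
    )
    return pipeline.reranked_chunks
```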

View File

@@ -6,6 +6,7 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.models import User
from danswer.llm.interfaces import LLM
from danswer.search.enums import QueryFlow
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import BaseFilters
@@ -31,6 +32,7 @@ logger = setup_logger()
def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
db_session: Session,
bypass_acl: bool = False,
include_query_intent: bool = True,
@@ -87,14 +89,14 @@
# Based on the query figure out if we should apply any hard time filters /
# if we should bias more recent docs even more strongly
run_time_filters = (
FunctionCall(extract_time_filter, (query,), {})
FunctionCall(extract_time_filter, (query, llm), {})
if auto_detect_time_filter
else None
)
# Based on the query, figure out if we should apply any source filters
run_source_filters = (
FunctionCall(extract_source_filter, (query, db_session), {})
FunctionCall(extract_source_filter, (query, llm, db_session), {})
if auto_detect_source_filter
else None
)
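
Both filter extractions now take the caller's LLM as a positional argument instead of constructing one themselves. Outside the parallel FunctionCall machinery, the new signatures read as below (the module paths for the two filter functions are assumed from the repo layout and may differ):

```python
from danswer.llm.factory import get_default_llm
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter

def detect_filters(query: str, db_session):
    # The caller now owns LLM construction (and any GenAIDisabledException).
    llm = get_default_llm()
    cutoff, favor_recent = extract_time_filter(query, llm)
    sources = extract_source_filter(query, llm, db_session)
    return cutoff, favor_recent, sources
```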

View File

@@ -6,8 +6,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_unique_document_sources
from danswer.db.engine import get_sqlalchemy_engine
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import SOURCES_KEY
@@ -44,7 +43,7 @@ def _sample_document_sources(
def extract_source_filter(
query: str, db_session: Session
query: str, llm: LLM, db_session: Session
) -> list[DocumentSource] | None:
"""Returns a list of valid sources for search or None if no specific sources were detected"""
@@ -147,11 +146,6 @@ def extract_source_filter(
logger.warning("LLM failed to provide a valid Source Filter output")
return None
try:
llm = get_default_llm()
except GenAIDisabledException:
return None
valid_sources = fetch_unique_document_sources(db_session)
if not valid_sources:
return None
@@ -165,9 +159,11 @@
if __name__ == "__main__":
from danswer.llm.factory import get_default_llm
# Just for testing purposes
with Session(get_sqlalchemy_engine()) as db_session:
while True:
user_input = input("Query to Extract Sources: ")
sources = extract_source_filter(user_input, db_session)
sources = extract_source_filter(user_input, get_default_llm(), db_session)
print(sources)

View File

@@ -5,8 +5,7 @@ from datetime import timezone
from dateutil.parser import parse
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -41,7 +40,7 @@ def best_match_time(time_str: str) -> datetime | None:
return None
def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
"""Returns a datetime if a hard time filter should be applied for the given query
Additionally returns a bool, True if more recently updated Documents should be
heavily favored"""
@@ -147,11 +146,6 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
return None, False
try:
llm = get_default_llm()
except GenAIDisabledException:
return None, False
messages = _get_time_filter_messages(query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
@@ -162,8 +156,10 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
if __name__ == "__main__":
# Just for testing purposes, too tedious to unit test as it relies on an LLM
from danswer.llm.factory import get_default_llm
while True:
user_input = input("Query to Extract Time: ")
cutoff, recency_bias = extract_time_filter(user_input)
cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
print(f"Time Cutoff: {cutoff}")
print(f"Favor Recent: {recency_bias}")

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.db.engine import get_session
from danswer.llm.factory import get_default_llm
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.server.danswer_api.ingestion import api_key_dep
@@ -71,6 +72,7 @@ def gpt_search(
query=search_request.query,
),
user=None,
llm=get_default_llm(),
db_session=db_session,
).reranked_chunks

View File

@@ -4,6 +4,7 @@ import uuid
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
@@ -41,6 +42,7 @@ from danswer.file_store.models import FileDescriptor
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
@@ -233,6 +235,7 @@ def delete_chat_session_by_id(
@router.post("/send-message")
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_user),
) -> StreamingResponse:
"""This endpoint is both used for all the following purposes:
@@ -256,6 +259,9 @@
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
),
)
return StreamingResponse(packets, media_type="application/json")
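
End to end, a client header reaches the LLM call only if its name appears in LITELLM_PASS_THROUGH_HEADERS. A hypothetical client using the requests library (host, header name, and payload fields are made up; the full CreateChatMessageRequest schema is elided):

```python
import json

import requests

# Assumes the server was started with
#   LITELLM_PASS_THROUGH_HEADERS='["X-User-Token"]'
# so this header is copied onto the LiteLLM completion call; headers not
# in that list are dropped by get_litellm_additional_request_headers.
resp = requests.post(
    "http://localhost:8080/chat/send-message",  # hypothetical host/port
    json={"message": "hello"},  # remaining request fields elided
    headers={"X-User-Token": "abc123"},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(json.loads(line))  # one chat packet per JSON line
```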

View File

@@ -15,7 +15,6 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
@@ -63,7 +62,7 @@ class SearchTool(Tool):
persona: Persona,
retrieval_options: RetrievalDetails | None,
prompt_config: PromptConfig,
llm_config: LLMConfig,
llm: LLM,
pruning_config: DocumentPruningConfig,
# if specified, will not actually run a search and will instead return these
# sections. Used when the user selects specific docs to talk to
@@ -76,7 +75,7 @@
self.persona = persona
self.retrieval_options = retrieval_options
self.prompt_config = prompt_config
self.llm_config = llm_config
self.llm = llm
self.pruning_config = pruning_config
self.selected_docs = selected_docs
@@ -175,7 +174,7 @@
docs=self.selected_docs,
doc_relevance_list=None,
prompt_config=self.prompt_config,
llm_config=self.llm_config,
llm_config=self.llm.config,
question=query,
document_pruning_config=self.pruning_config,
),
@@ -191,9 +190,9 @@
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
human_selected_filters=self.retrieval_options.filters
if self.retrieval_options
else None,
human_selected_filters=(
self.retrieval_options.filters if self.retrieval_options else None
),
persona=self.persona,
offset=self.retrieval_options.offset
if self.retrieval_options
@@ -204,6 +203,7 @@
full_doc=self.full_doc,
),
user=self.user,
llm=self.llm,
db_session=self.db_session,
)
yield ToolResponse(
@@ -233,7 +233,7 @@
for ind in range(len(llm_docs))
],
prompt_config=self.prompt_config,
llm_config=self.llm_config,
llm_config=self.llm.config,
question=query,
document_pruning_config=self.pruning_config,
)
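
Storing the full LLM on SearchTool (instead of just an LLMConfig) lets the tool hand a live LLM to SearchPipeline while still deriving llm_config on demand via llm.config for document pruning, so call sites construct and pass a single object.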

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.llm.answering.doc_pruning import reorder_docs
from danswer.llm.factory import get_default_llm
from danswer.search.models import InferenceChunk
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
@@ -87,6 +88,7 @@ def get_search_results(
query=query,
),
user=None,
llm=get_default_llm(),
db_session=db_session,
retrieval_metrics_callback=retrieval_metrics.record_metric,
rerank_metrics_callback=rerank_metrics.record_metric,