From b723627e0cf989cfde5734f825bfac5fdde126bf Mon Sep 17 00:00:00 2001
From: Weves
Date: Mon, 10 Jun 2024 13:02:05 -0700
Subject: [PATCH] Ability to pass through headers to LLM call
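
Adds an optional `litellm_additional_headers` argument that is threaded from
the API layer down to the underlying LiteLLM call. Which incoming request
headers get forwarded is controlled by the new LITELLM_PASS_THROUGH_HEADERS
environment variable, a JSON list of header names. A sketch of the intended
deployment config (the header names here are invented for illustration):

    LITELLM_PASS_THROUGH_HEADERS='["X-Org-Id", "X-Request-Id"]'

Statically configured LITELLM_EXTRA_HEADERS still apply and, because they are
merged last in get_llm(), win any key collision with pass-through headers.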
---
 backend/danswer/chat/process_message.py       | 13 ++++++++---
 backend/danswer/configs/model_configs.py      | 17 +++++++++++++-
 backend/danswer/llm/factory.py                | 16 ++++++++++++--
 backend/danswer/llm/headers.py                | 22 +++++++++++++++++++
 .../one_shot_answer/answer_question.py        |  2 +-
 backend/danswer/search/pipeline.py            |  4 ++++
 .../search/preprocessing/preprocessing.py     |  6 +++--
 .../secondary_llm_flows/source_filter.py      | 14 +++++-------
 .../secondary_llm_flows/time_filter.py        | 14 +++++-------
 backend/danswer/server/gpts/api.py            |  2 ++
 .../server/query_and_chat/chat_backend.py     |  6 +++++
 backend/danswer/tools/search/search_tool.py   | 16 +++++++-------
 .../regression/search_quality/eval_search.py  |  2 ++
 13 files changed, 99 insertions(+), 35 deletions(-)
 create mode 100644 backend/danswer/llm/headers.py

diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 7ba7d991a..723815421 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -193,6 +193,7 @@ def stream_chat_message_objects(
     # on the `new_msg_req.message`. Currently, requires a state where the last message is a
     # user message (e.g. this can only be used for the chat-seeding flow).
     use_existing_user_message: bool = False,
+    litellm_additional_headers: dict[str, str] | None = None,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -228,7 +229,9 @@
 
     try:
         llm = get_llm_for_persona(
-            persona, new_msg_req.llm_override or chat_session.llm_override
+            persona=persona,
+            llm_override=new_msg_req.llm_override or chat_session.llm_override,
+            additional_headers=litellm_additional_headers,
         )
     except GenAIDisabledException:
         raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -410,7 +413,7 @@
             persona=persona,
             retrieval_options=retrieval_options,
             prompt_config=prompt_config,
-            llm_config=llm.config,
+            llm=llm,
             pruning_config=document_pruning_config,
             selected_docs=selected_llm_docs,
             chunks_above=new_msg_req.chunks_above,
@@ -455,7 +458,9 @@
             llm=(
                 llm
                 or get_llm_for_persona(
-                    persona, new_msg_req.llm_override or chat_session.llm_override
+                    persona=persona,
+                    llm_override=new_msg_req.llm_override or chat_session.llm_override,
+                    additional_headers=litellm_additional_headers,
                 )
             ),
             message_history=[
@@ -576,6 +581,7 @@ def stream_chat_message(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
     use_existing_user_message: bool = False,
+    litellm_additional_headers: dict[str, str] | None = None,
 ) -> Iterator[str]:
     with get_session_context_manager() as db_session:
         objects = stream_chat_message_objects(
@@ -583,6 +589,7 @@
             user=user,
             db_session=db_session,
             use_existing_user_message=use_existing_user_message,
+            litellm_additional_headers=litellm_additional_headers,
         )
         for obj in objects:
             yield get_json_line(obj.dict())

diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 151b41811..f5be89779 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -100,7 +100,7 @@ DISABLE_LITELLM_STREAMING = (
 ).lower() == "true"
 
 # extra headers to pass to LiteLLM
-LITELLM_EXTRA_HEADERS = None
+LITELLM_EXTRA_HEADERS: dict[str, str] | None = None
 _LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS")
 if _LITELLM_EXTRA_HEADERS_RAW:
     try:
@@ -113,3 +113,18 @@ if _LITELLM_EXTRA_HEADERS_RAW:
         logger.error(
             "Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
         )
+
+# if specified, will pass through request headers to the call to the LLM
+LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None
+_LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
+if _LITELLM_PASS_THROUGH_HEADERS_RAW:
+    try:
+        LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW)
+    except Exception:
+        # need to import here to avoid circular imports
+        from danswer.utils.logger import setup_logger
+
+        logger = setup_logger()
+        logger.error(
+            "Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON list"
+        )
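
A quick sanity check of the parsing behavior above (standalone sketch, not
repo code; the header names are invented):

    import json
    import os

    os.environ["LITELLM_PASS_THROUGH_HEADERS"] = '["X-Org-Id", "X-Request-Id"]'

    raw = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
    # mirrors the config module: None when unset, a list of header names when valid
    headers_to_pass = json.loads(raw) if raw else None
    print(headers_to_pass)  # ['X-Org-Id', 'X-Request-Id']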
diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py
index 9c92eb9a6..d85dbdc9f 100644
--- a/backend/danswer/llm/factory.py
+++ b/backend/danswer/llm/factory.py
@@ -13,7 +13,9 @@ from danswer.llm.override_models import LLMOverride
 
 
 def get_llm_for_persona(
-    persona: Persona, llm_override: LLMOverride | None = None
+    persona: Persona,
+    llm_override: LLMOverride | None = None,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     model_provider_override = llm_override.model_provider if llm_override else None
     model_version_override = llm_override.model_version if llm_override else None
@@ -25,6 +27,7 @@
         ),
         model_version=(model_version_override or persona.llm_model_version_override),
         temperature=temperature_override or GEN_AI_TEMPERATURE,
+        additional_headers=additional_headers,
     )
 
 
@@ -34,6 +37,7 @@ def get_default_llm(
     use_fast_llm: bool = False,
     model_provider_name: str | None = None,
     model_version: str | None = None,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()
@@ -65,6 +69,7 @@ def get_default_llm(
         custom_config=llm_provider.custom_config,
         timeout=timeout,
         temperature=temperature,
+        additional_headers=additional_headers,
     )
 
 
@@ -77,7 +82,14 @@ def get_llm(
     custom_config: dict[str, str] | None = None,
     temperature: float = GEN_AI_TEMPERATURE,
     timeout: int = QA_TIMEOUT,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
+    extra_headers = {}
+    if additional_headers:
+        extra_headers.update(additional_headers)
+    if LITELLM_EXTRA_HEADERS:
+        extra_headers.update(LITELLM_EXTRA_HEADERS)
+
     return DefaultMultiLLM(
         model_provider=provider,
         model_name=model,
@@ -87,5 +99,5 @@
         timeout=timeout,
         temperature=temperature,
         custom_config=custom_config,
-        extra_headers=LITELLM_EXTRA_HEADERS,
+        extra_headers=extra_headers,
     )

diff --git a/backend/danswer/llm/headers.py b/backend/danswer/llm/headers.py
new file mode 100644
index 000000000..f7ae7436f
--- /dev/null
+++ b/backend/danswer/llm/headers.py
@@ -0,0 +1,22 @@
+from fastapi.datastructures import Headers
+
+from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
+
+
+def get_litellm_additional_request_headers(
+    headers: dict[str, str] | Headers
+) -> dict[str, str]:
+    if not LITELLM_PASS_THROUGH_HEADERS:
+        return {}
+
+    pass_through_headers: dict[str, str] = {}
+    for key in LITELLM_PASS_THROUGH_HEADERS:
+        if key in headers:
+            pass_through_headers[key] = headers[key]
+        else:
+            # fastapi makes all header keys lowercase, handling that here
+            lowercase_key = key.lower()
+            if lowercase_key in headers:
+                pass_through_headers[lowercase_key] = headers[lowercase_key]
+
+    return pass_through_headers
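
An illustrative use of the helper above (standalone sketch; the header name is
invented, and the env var must be set before the config module is first
imported, since it is read at import time):

    import os

    os.environ["LITELLM_PASS_THROUGH_HEADERS"] = '["X-Org-Id"]'

    from fastapi.datastructures import Headers

    from danswer.llm.headers import get_litellm_additional_request_headers

    # Headers lookups are case-insensitive, so "X-Org-Id" matches directly
    print(get_litellm_additional_request_headers(Headers({"x-org-id": "acme"})))
    # -> {'X-Org-Id': 'acme'}

    # with a plain dict of lowercased keys, the explicit fallback kicks in
    print(get_litellm_additional_request_headers({"x-org-id": "acme"}))
    # -> {'x-org-id': 'acme'}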
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 39ac46ba6..3b29920ac 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -172,7 +172,7 @@
         persona=chat_session.persona,
         retrieval_options=query_req.retrieval_options,
         prompt_config=prompt_config,
-        llm_config=llm.config,
+        llm=llm,
         pruning_config=document_pruning_config,
     )
 
diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py
index 0c757232a..7f68178bf 100644
--- a/backend/danswer/search/pipeline.py
+++ b/backend/danswer/search/pipeline.py
@@ -10,6 +10,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
+from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
 from danswer.search.models import IndexFilters
@@ -54,6 +55,7 @@ class SearchPipeline:
         self,
         search_request: SearchRequest,
         user: User | None,
+        llm: LLM,
         db_session: Session,
         bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
         retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -62,6 +64,7 @@
     ):
         self.search_request = search_request
         self.user = user
+        self.llm = llm
         self.db_session = db_session
         self.bypass_acl = bypass_acl
         self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -229,6 +232,7 @@
         ) = retrieval_preprocessing(
             search_request=self.search_request,
             user=self.user,
+            llm=self.llm,
             db_session=self.db_session,
             bypass_acl=self.bypass_acl,
         )

diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py
index ab22c5d67..b4be8dca6 100644
--- a/backend/danswer/search/preprocessing/preprocessing.py
+++ b/backend/danswer/search/preprocessing/preprocessing.py
@@ -6,6 +6,7 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.db.models import User
+from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import BaseFilters
@@ -31,6 +32,7 @@ logger = setup_logger()
 def retrieval_preprocessing(
     search_request: SearchRequest,
     user: User | None,
+    llm: LLM,
     db_session: Session,
     bypass_acl: bool = False,
     include_query_intent: bool = True,
@@ -87,14 +89,14 @@
     # Based on the query figure out if we should apply any hard time filters /
     # if we should bias more recent docs even more strongly
     run_time_filters = (
-        FunctionCall(extract_time_filter, (query,), {})
+        FunctionCall(extract_time_filter, (query, llm), {})
         if auto_detect_time_filter
         else None
     )
 
     # Based on the query, figure out if we should apply any source filters
     run_source_filters = (
-        FunctionCall(extract_source_filter, (query, db_session), {})
+        FunctionCall(extract_source_filter, (query, llm, db_session), {})
         if auto_detect_source_filter
         else None
     )

diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py
index 6a27963ff..78dc504ab 100644
--- a/backend/danswer/secondary_llm_flows/source_filter.py
+++ b/backend/danswer/secondary_llm_flows/source_filter.py
@@ -6,8 +6,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.constants import SOURCES_KEY
@@ -44,7 +43,7 @@ def _sample_document_sources(
 
 
 def extract_source_filter(
-    query: str, db_session: Session
+    query: str, llm: LLM, db_session: Session
 ) -> list[DocumentSource] | None:
     """Returns a list of valid sources for search or None if no specific sources were detected"""
 
@@ -147,11 +146,6 @@
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None
 
-    try:
-        llm = get_default_llm()
-    except GenAIDisabledException:
-        return None
-
     valid_sources = fetch_unique_document_sources(db_session)
     if not valid_sources:
         return None
@@ -165,9 +159,11 @@
 
 
 if __name__ == "__main__":
+    from danswer.llm.factory import get_default_llm
+
     # Just for testing purposes
     with Session(get_sqlalchemy_engine()) as db_session:
         while True:
             user_input = input("Query to Extract Sources: ")
-            sources = extract_source_filter(user_input, db_session)
+            sources = extract_source_filter(user_input, get_default_llm(), db_session)
             print(sources)
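
With this refactor the filter flows no longer build their own LLM (or silently
return early when gen-AI is disabled); the caller injects one and owns the
GenAIDisabledException handling. A rough sketch of the new calling convention,
mirroring the __main__ block above:

    from sqlalchemy.orm import Session

    from danswer.db.engine import get_sqlalchemy_engine
    from danswer.llm.factory import get_default_llm
    from danswer.secondary_llm_flows.source_filter import extract_source_filter

    llm = get_default_llm()  # raises GenAIDisabledException when gen-AI is off
    with Session(get_sqlalchemy_engine()) as db_session:
        print(extract_source_filter("recent jira tickets", llm, db_session))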
diff --git a/backend/danswer/secondary_llm_flows/time_filter.py b/backend/danswer/secondary_llm_flows/time_filter.py
index 9080dc1f9..7de6efb3d 100644
--- a/backend/danswer/secondary_llm_flows/time_filter.py
+++ b/backend/danswer/secondary_llm_flows/time_filter.py
@@ -5,8 +5,7 @@ from datetime import timezone
 
 from dateutil.parser import parse
 
-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -41,7 +40,7 @@ def best_match_time(time_str: str) -> datetime | None:
     return None
 
 
-def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
+def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
     """Returns a datetime if a hard time filter should be applied for the given query
     Additionally returns a bool, True if more recently updated Documents should be heavily favored"""
 
@@ -147,11 +146,6 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
 
         return None, False
 
-    try:
-        llm = get_default_llm()
-    except GenAIDisabledException:
-        return None, False
-
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     model_output = message_to_string(llm.invoke(filled_llm_prompt))
@@ -162,8 +156,10 @@
 
 if __name__ == "__main__":
     # Just for testing purposes, too tedious to unit test as it relies on an LLM
+    from danswer.llm.factory import get_default_llm
+
     while True:
         user_input = input("Query to Extract Time: ")
-        cutoff, recency_bias = extract_time_filter(user_input)
+        cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
         print(f"Time Cutoff: {cutoff}")
         print(f"Favor Recent: {recency_bias}")

diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py
index ca6978b57..cefb92341 100644
--- a/backend/danswer/server/gpts/api.py
+++ b/backend/danswer/server/gpts/api.py
@@ -7,6 +7,7 @@ from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session
+from danswer.llm.factory import get_default_llm
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
@@ -71,6 +72,7 @@ def gpt_search(
             query=search_request.query,
         ),
         user=None,
+        llm=get_default_llm(),
         db_session=db_session,
     ).reranked_chunks
 
diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py
index ab72699c2..6152bc1d8 100644
--- a/backend/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/danswer/server/query_and_chat/chat_backend.py
@@ -4,6 +4,7 @@ import uuid
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
+from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
 from fastapi.responses import StreamingResponse
@@ -41,6 +42,7 @@ from danswer.file_store.models import FileDescriptor
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.headers import get_litellm_additional_request_headers
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
     get_renamed_conversation_name,
@@ -233,6 +235,7 @@ def delete_chat_session_by_id(
 @router.post("/send-message")
 def handle_new_chat_message(
     chat_message_req: CreateChatMessageRequest,
+    request: Request,
     user: User | None = Depends(current_user),
 ) -> StreamingResponse:
     """This endpoint is used for all the following purposes:
@@ -256,6 +259,9 @@
         new_msg_req=chat_message_req,
         user=user,
         use_existing_user_message=chat_message_req.use_existing_user_message,
+        litellm_additional_headers=get_litellm_additional_request_headers(
+            request.headers
+        ),
     )
 
     return StreamingResponse(packets, media_type="application/json")
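
End to end, an API client can now attach headers that reach the LLM call. A
sketch using the requests library (header name invented; the URL assumes a
local deployment with the /chat router prefix, and the CreateChatMessageRequest
payload fields are elided):

    import requests

    payload: dict = {}  # fill in a valid CreateChatMessageRequest body
    response = requests.post(
        "http://localhost:8080/chat/send-message",
        json=payload,
        # forwarded to LiteLLM only if listed in LITELLM_PASS_THROUGH_HEADERS
        headers={"X-Org-Id": "acme"},
        stream=True,
    )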
diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py
index 968c17f5a..2a272b6ab 100644
--- a/backend/danswer/tools/search/search_tool.py
+++ b/backend/danswer/tools/search/search_tool.py
@@ -15,7 +15,6 @@ from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.interfaces import LLM
-from danswer.llm.interfaces import LLMConfig
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
 from danswer.search.models import IndexFilters
@@ -63,7 +62,7 @@
         persona: Persona,
         retrieval_options: RetrievalDetails | None,
         prompt_config: PromptConfig,
-        llm_config: LLMConfig,
+        llm: LLM,
         pruning_config: DocumentPruningConfig,
         # if specified, will not actually run a search and will instead return these
         # sections. Used when the user selects specific docs to talk to
@@ -76,7 +75,7 @@
         self.persona = persona
         self.retrieval_options = retrieval_options
         self.prompt_config = prompt_config
-        self.llm_config = llm_config
+        self.llm = llm
         self.pruning_config = pruning_config
         self.selected_docs = selected_docs
 
@@ -175,7 +174,7 @@
                     docs=self.selected_docs,
                     doc_relevance_list=None,
                     prompt_config=self.prompt_config,
-                    llm_config=self.llm_config,
+                    llm_config=self.llm.config,
                     question=query,
                     document_pruning_config=self.pruning_config,
                 ),
@@ -191,9 +190,9 @@
         search_pipeline = SearchPipeline(
             search_request=SearchRequest(
                 query=query,
-                human_selected_filters=self.retrieval_options.filters
-                if self.retrieval_options
-                else None,
+                human_selected_filters=(
+                    self.retrieval_options.filters if self.retrieval_options else None
+                ),
                 persona=self.persona,
                 offset=self.retrieval_options.offset
                 if self.retrieval_options
@@ -204,6 +203,7 @@
                 full_doc=self.full_doc,
             ),
             user=self.user,
+            llm=self.llm,
             db_session=self.db_session,
         )
         yield ToolResponse(
@@ -233,7 +233,7 @@
                 for ind in range(len(llm_docs))
             ],
             prompt_config=self.prompt_config,
-            llm_config=self.llm_config,
+            llm_config=self.llm.config,
             question=query,
             document_pruning_config=self.pruning_config,
         )
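
Handing SearchTool the full LLM (rather than just its LLMConfig) lets it drive
the LLM-based filter extraction inside SearchPipeline while still deriving the
config where only that is needed, via the `config` property:

    from danswer.llm.factory import get_default_llm

    llm = get_default_llm()
    llm_config = llm.config  # what SearchTool previously had to be handed directly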
diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py
index 23eefc45c..e67fe7c5b 100644
--- a/backend/tests/regression/search_quality/eval_search.py
+++ b/backend/tests/regression/search_quality/eval_search.py
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.llm.answering.doc_pruning import reorder_docs
+from danswer.llm.factory import get_default_llm
 from danswer.search.models import InferenceChunk
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -87,6 +88,7 @@ def get_search_results(
             query=query,
         ),
         user=None,
+        llm=get_default_llm(),
         db_session=db_session,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
         rerank_metrics_callback=rerank_metrics.record_metric,