Ability to pass through headers to LLM call

commit b723627e0c (parent 180b592afe)
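In short: this commit adds a LITELLM_PASS_THROUGH_HEADERS config option, a new get_litellm_additional_request_headers helper that picks the listed headers off the incoming FastAPI request, and a litellm_additional_headers parameter that is threaded from the /send-message endpoint through stream_chat_message, get_llm_for_persona / get_default_llm / get_llm, and finally into DefaultMultiLLM as extra_headers. Along the way, the secondary LLM flows (extract_time_filter, extract_source_filter) and SearchPipeline / SearchTool are refactored to take an injected LLM instance instead of constructing a default one internally.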
@@ -193,6 +193,7 @@ def stream_chat_message_objects(
     # on the `new_msg_req.message`. Currently, requires a state where the last message is a
     # user message (e.g. this can only be used for the chat-seeding flow).
     use_existing_user_message: bool = False,
+    litellm_additional_headers: dict[str, str] | None = None,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -228,7 +229,9 @@ def stream_chat_message_objects(
 
     try:
         llm = get_llm_for_persona(
-            persona, new_msg_req.llm_override or chat_session.llm_override
+            persona=persona,
+            llm_override=new_msg_req.llm_override or chat_session.llm_override,
+            additional_headers=litellm_additional_headers,
         )
     except GenAIDisabledException:
         raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -410,7 +413,7 @@ def stream_chat_message_objects(
             persona=persona,
             retrieval_options=retrieval_options,
             prompt_config=prompt_config,
-            llm_config=llm.config,
+            llm=llm,
             pruning_config=document_pruning_config,
             selected_docs=selected_llm_docs,
             chunks_above=new_msg_req.chunks_above,
@@ -455,7 +458,9 @@ def stream_chat_message_objects(
             llm=(
                 llm
                 or get_llm_for_persona(
-                    persona, new_msg_req.llm_override or chat_session.llm_override
+                    persona=persona,
+                    llm_override=new_msg_req.llm_override or chat_session.llm_override,
+                    additional_headers=litellm_additional_headers,
                 )
             ),
             message_history=[
@@ -576,6 +581,7 @@ def stream_chat_message(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
     use_existing_user_message: bool = False,
+    litellm_additional_headers: dict[str, str] | None = None,
 ) -> Iterator[str]:
     with get_session_context_manager() as db_session:
         objects = stream_chat_message_objects(
@@ -583,6 +589,7 @@ def stream_chat_message(
             user=user,
             db_session=db_session,
             use_existing_user_message=use_existing_user_message,
+            litellm_additional_headers=litellm_additional_headers,
         )
         for obj in objects:
             yield get_json_line(obj.dict())
@@ -100,7 +100,7 @@ DISABLE_LITELLM_STREAMING = (
 ).lower() == "true"
 
 # extra headers to pass to LiteLLM
-LITELLM_EXTRA_HEADERS = None
+LITELLM_EXTRA_HEADERS: dict[str, str] | None = None
 _LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS")
 if _LITELLM_EXTRA_HEADERS_RAW:
     try:
@@ -113,3 +113,18 @@ if _LITELLM_EXTRA_HEADERS_RAW:
         logger.error(
             "Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object"
         )
+
+# if specified, will pass through request headers to the call to the LLM
+LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None
+_LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS")
+if _LITELLM_PASS_THROUGH_HEADERS_RAW:
+    try:
+        LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW)
+    except Exception:
+        # need to import here to avoid circular imports
+        from danswer.utils.logger import setup_logger
+
+        logger = setup_logger()
+        logger.error(
+            "Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
+        )
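For reference, both settings are read from environment variables and parsed with json.loads, so a deployment sets them as JSON strings. A minimal sketch of valid values (the header names below are hypothetical, not part of the commit):

```python
import json

# LITELLM_EXTRA_HEADERS must parse to a JSON object (header name -> value);
# LITELLM_PASS_THROUGH_HEADERS must parse to a JSON list of header names.
# Both example values are hypothetical.
extra_headers_raw = '{"X-Proxy-Key": "some-key"}'
pass_through_raw = '["Authorization", "X-Request-Id"]'

assert json.loads(extra_headers_raw) == {"X-Proxy-Key": "some-key"}
assert json.loads(pass_through_raw) == ["Authorization", "X-Request-Id"]
```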
@@ -13,7 +13,9 @@ from danswer.llm.override_models import LLMOverride
 
 
 def get_llm_for_persona(
-    persona: Persona, llm_override: LLMOverride | None = None
+    persona: Persona,
+    llm_override: LLMOverride | None = None,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     model_provider_override = llm_override.model_provider if llm_override else None
     model_version_override = llm_override.model_version if llm_override else None
@@ -25,6 +27,7 @@ def get_llm_for_persona(
         ),
         model_version=(model_version_override or persona.llm_model_version_override),
         temperature=temperature_override or GEN_AI_TEMPERATURE,
+        additional_headers=additional_headers,
     )
 
 
@@ -34,6 +37,7 @@ def get_default_llm(
     use_fast_llm: bool = False,
     model_provider_name: str | None = None,
     model_version: str | None = None,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()
@@ -65,6 +69,7 @@ def get_default_llm(
         custom_config=llm_provider.custom_config,
         timeout=timeout,
         temperature=temperature,
+        additional_headers=additional_headers,
     )
 
 
@@ -77,7 +82,14 @@ def get_llm(
     custom_config: dict[str, str] | None = None,
     temperature: float = GEN_AI_TEMPERATURE,
     timeout: int = QA_TIMEOUT,
+    additional_headers: dict[str, str] | None = None,
 ) -> LLM:
+    extra_headers = {}
+    if additional_headers:
+        extra_headers.update(additional_headers)
+    if LITELLM_EXTRA_HEADERS:
+        extra_headers.update(LITELLM_EXTRA_HEADERS)
+
     return DefaultMultiLLM(
         model_provider=provider,
         model_name=model,
@@ -87,5 +99,5 @@ def get_llm(
         timeout=timeout,
         temperature=temperature,
         custom_config=custom_config,
-        extra_headers=LITELLM_EXTRA_HEADERS,
+        extra_headers=extra_headers,
     )
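One consequence of the merge order in get_llm above: per-request pass-through headers are applied first and the instance-wide LITELLM_EXTRA_HEADERS second, so on a key collision the instance-wide value wins. A small standalone sketch (header names hypothetical):

```python
# Mirrors the merge logic in get_llm; header names are hypothetical.
additional_headers = {"X-Request-Id": "abc", "X-Team": "per-request"}
LITELLM_EXTRA_HEADERS = {"X-Team": "instance-wide"}

extra_headers: dict[str, str] = {}
extra_headers.update(additional_headers)
extra_headers.update(LITELLM_EXTRA_HEADERS)

# Instance-wide value overrides the per-request one for "X-Team".
assert extra_headers == {"X-Request-Id": "abc", "X-Team": "instance-wide"}
```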
backend/danswer/llm/headers.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from fastapi.datastructures import Headers
+
+from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
+
+
+def get_litellm_additional_request_headers(
+    headers: dict[str, str] | Headers
+) -> dict[str, str]:
+    if not LITELLM_PASS_THROUGH_HEADERS:
+        return {}
+
+    pass_through_headers: dict[str, str] = {}
+    for key in LITELLM_PASS_THROUGH_HEADERS:
+        if key in headers:
+            pass_through_headers[key] = headers[key]
+        else:
+            # fastapi makes all header keys lowercase, handling that here
+            lowercase_key = key.lower()
+            if lowercase_key in headers:
+                pass_through_headers[lowercase_key] = headers[lowercase_key]
+
+    return pass_through_headers
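A quick sketch of how the new helper behaves, rebinding the module-level config for illustration only (in a real deployment it is parsed from the LITELLM_PASS_THROUGH_HEADERS env var; header names hypothetical):

```python
from fastapi.datastructures import Headers

from danswer.llm import headers as llm_headers

# Rebind the module-level config for this sketch; normally it comes from
# the LITELLM_PASS_THROUGH_HEADERS env var. Header names are hypothetical.
llm_headers.LITELLM_PASS_THROUGH_HEADERS = ["Authorization", "X-Request-Id"]

# Starlette's Headers matches keys case-insensitively; the lower-cased
# fallback in the helper covers plain dicts whose keys FastAPI has
# already lower-cased.
request_headers = Headers({"authorization": "Bearer abc123", "x-other": "dropped"})

print(llm_headers.get_litellm_additional_request_headers(request_headers))
# -> {'Authorization': 'Bearer abc123'}  ("X-Request-Id" is absent, so skipped)
```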
@@ -172,7 +172,7 @@ def stream_answer_objects(
             persona=chat_session.persona,
             retrieval_options=query_req.retrieval_options,
             prompt_config=prompt_config,
-            llm_config=llm.config,
+            llm=llm,
             pruning_config=document_pruning_config,
         )
 
@@ -10,6 +10,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
+from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
 from danswer.search.models import IndexFilters
@@ -54,6 +55,7 @@ class SearchPipeline:
         self,
         search_request: SearchRequest,
         user: User | None,
+        llm: LLM,
         db_session: Session,
         bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
         retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -62,6 +64,7 @@ class SearchPipeline:
     ):
         self.search_request = search_request
         self.user = user
+        self.llm = llm
         self.db_session = db_session
         self.bypass_acl = bypass_acl
         self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -229,6 +232,7 @@ class SearchPipeline:
         ) = retrieval_preprocessing(
             search_request=self.search_request,
             user=self.user,
+            llm=self.llm,
             db_session=self.db_session,
             bypass_acl=self.bypass_acl,
         )
@@ -6,6 +6,7 @@ from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.db.models import User
+from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import BaseFilters
@@ -31,6 +32,7 @@ logger = setup_logger()
 def retrieval_preprocessing(
     search_request: SearchRequest,
     user: User | None,
+    llm: LLM,
     db_session: Session,
     bypass_acl: bool = False,
     include_query_intent: bool = True,
@@ -87,14 +89,14 @@ def retrieval_preprocessing(
     # Based on the query figure out if we should apply any hard time filters /
     # if we should bias more recent docs even more strongly
     run_time_filters = (
-        FunctionCall(extract_time_filter, (query,), {})
+        FunctionCall(extract_time_filter, (query, llm), {})
         if auto_detect_time_filter
         else None
     )
 
     # Based on the query, figure out if we should apply any source filters
     run_source_filters = (
-        FunctionCall(extract_source_filter, (query, db_session), {})
+        FunctionCall(extract_source_filter, (query, llm, db_session), {})
         if auto_detect_source_filter
         else None
     )
@@ -6,8 +6,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.constants import SOURCES_KEY
@@ -44,7 +43,7 @@ def _sample_document_sources(
 
 
 def extract_source_filter(
-    query: str, db_session: Session
+    query: str, llm: LLM, db_session: Session
 ) -> list[DocumentSource] | None:
     """Returns a list of valid sources for search or None if no specific sources were detected"""
 
@@ -147,11 +146,6 @@ def extract_source_filter(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None
 
-    try:
-        llm = get_default_llm()
-    except GenAIDisabledException:
-        return None
-
     valid_sources = fetch_unique_document_sources(db_session)
     if not valid_sources:
         return None
@@ -165,9 +159,11 @@ def extract_source_filter(
 
 
 if __name__ == "__main__":
+    from danswer.llm.factory import get_default_llm
+
     # Just for testing purposes
     with Session(get_sqlalchemy_engine()) as db_session:
         while True:
             user_input = input("Query to Extract Sources: ")
-            sources = extract_source_filter(user_input, db_session)
+            sources = extract_source_filter(user_input, get_default_llm(), db_session)
             print(sources)
@@ -5,8 +5,7 @@ from datetime import timezone
 
 from dateutil.parser import parse
 
-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -41,7 +40,7 @@ def best_match_time(time_str: str) -> datetime | None:
     return None
 
 
-def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
+def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
     """Returns a datetime if a hard time filter should be applied for the given query
     Additionally returns a bool, True if more recently updated Documents should be
     heavily favored"""
@@ -147,11 +146,6 @@ def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
 
         return None, False
 
-    try:
-        llm = get_default_llm()
-    except GenAIDisabledException:
-        return None, False
-
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     model_output = message_to_string(llm.invoke(filled_llm_prompt))
@@ -162,8 +156,10 @@ def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
 
 if __name__ == "__main__":
     # Just for testing purposes, too tedious to unit test as it relies on an LLM
+    from danswer.llm.factory import get_default_llm
+
     while True:
         user_input = input("Query to Extract Time: ")
-        cutoff, recency_bias = extract_time_filter(user_input)
+        cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
         print(f"Time Cutoff: {cutoff}")
         print(f"Favor Recent: {recency_bias}")
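The refactor above inverts the dependency: extract_time_filter and extract_source_filter no longer construct their own default LLM (and no longer need the GenAIDisabledException guard); the caller builds one LLM, with any per-request headers baked in, and injects it. A sketch of the new calling convention, assuming the module path for the time filter flow (header name hypothetical):

```python
from danswer.llm.factory import get_default_llm
from danswer.secondary_llm_flows.time_filter import extract_time_filter  # path assumed

# The caller constructs the LLM once -- optionally with per-request
# headers -- and passes it down; "X-Request-Id" is a hypothetical header.
llm = get_default_llm(additional_headers={"X-Request-Id": "abc"})
cutoff, favor_recent = extract_time_filter("docs changed in the last week", llm)
```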
@@ -7,6 +7,7 @@ from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session
+from danswer.llm.factory import get_default_llm
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
@@ -71,6 +72,7 @@ def gpt_search(
             query=search_request.query,
         ),
         user=None,
+        llm=get_default_llm(),
         db_session=db_session,
     ).reranked_chunks
 
@@ -4,6 +4,7 @@ import uuid
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
+from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
 from fastapi.responses import StreamingResponse
@@ -41,6 +42,7 @@ from danswer.file_store.models import FileDescriptor
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.headers import get_litellm_additional_request_headers
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
     get_renamed_conversation_name,
@@ -233,6 +235,7 @@ def delete_chat_session_by_id(
 @router.post("/send-message")
 def handle_new_chat_message(
     chat_message_req: CreateChatMessageRequest,
+    request: Request,
     user: User | None = Depends(current_user),
 ) -> StreamingResponse:
     """This endpoint is both used for all the following purposes:
@@ -256,6 +259,9 @@ def handle_new_chat_message(
         new_msg_req=chat_message_req,
         user=user,
         use_existing_user_message=chat_message_req.use_existing_user_message,
+        litellm_additional_headers=get_litellm_additional_request_headers(
+            request.headers
+        ),
     )
 
     return StreamingResponse(packets, media_type="application/json")
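End to end, a header listed in LITELLM_PASS_THROUGH_HEADERS now travels from the incoming chat request into the LiteLLM call. A hypothetical client call against the send-message endpoint (URL, port, and payload fields are illustrative only, not from the commit):

```python
import requests

# Hypothetical client call; assumes LITELLM_PASS_THROUGH_HEADERS includes
# "X-Request-Id" on the server. Payload fields are abbreviated/illustrative.
resp = requests.post(
    "http://localhost:8080/chat/send-message",
    json={"chat_session_id": 1, "message": "hello"},
    headers={"X-Request-Id": "abc-123"},  # forwarded on to the LLM call
    stream=True,
)
for line in resp.iter_lines():
    print(line)
```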
@@ -15,7 +15,6 @@ from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.interfaces import LLM
-from danswer.llm.interfaces import LLMConfig
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
 from danswer.search.models import IndexFilters
@@ -63,7 +62,7 @@ class SearchTool(Tool):
         persona: Persona,
         retrieval_options: RetrievalDetails | None,
         prompt_config: PromptConfig,
-        llm_config: LLMConfig,
+        llm: LLM,
         pruning_config: DocumentPruningConfig,
         # if specified, will not actually run a search and will instead return these
         # sections. Used when the user selects specific docs to talk to
@@ -76,7 +75,7 @@ class SearchTool(Tool):
         self.persona = persona
         self.retrieval_options = retrieval_options
         self.prompt_config = prompt_config
-        self.llm_config = llm_config
+        self.llm = llm
         self.pruning_config = pruning_config
 
         self.selected_docs = selected_docs
@@ -175,7 +174,7 @@ class SearchTool(Tool):
                 docs=self.selected_docs,
                 doc_relevance_list=None,
                 prompt_config=self.prompt_config,
-                llm_config=self.llm_config,
+                llm_config=self.llm.config,
                 question=query,
                 document_pruning_config=self.pruning_config,
             ),
@@ -191,9 +190,9 @@ class SearchTool(Tool):
         search_pipeline = SearchPipeline(
             search_request=SearchRequest(
                 query=query,
-                human_selected_filters=self.retrieval_options.filters
-                if self.retrieval_options
-                else None,
+                human_selected_filters=(
+                    self.retrieval_options.filters if self.retrieval_options else None
+                ),
                 persona=self.persona,
                 offset=self.retrieval_options.offset
                 if self.retrieval_options
@@ -204,6 +203,7 @@ class SearchTool(Tool):
                 full_doc=self.full_doc,
             ),
             user=self.user,
+            llm=self.llm,
             db_session=self.db_session,
         )
         yield ToolResponse(
@@ -233,7 +233,7 @@ class SearchTool(Tool):
                 for ind in range(len(llm_docs))
             ],
             prompt_config=self.prompt_config,
-            llm_config=self.llm_config,
+            llm_config=self.llm.config,
             question=query,
             document_pruning_config=self.pruning_config,
         )
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.llm.answering.doc_pruning import reorder_docs
+from danswer.llm.factory import get_default_llm
 from danswer.search.models import InferenceChunk
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -87,6 +88,7 @@ def get_search_results(
             query=query,
         ),
         user=None,
+        llm=get_default_llm(),
         db_session=db_session,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
         rerank_metrics_callback=rerank_metrics.record_metric,