mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 20:08:36 +02:00
Add custom tool headers (#2773)
* add custom tool headers * simplify * k * k * k * nit
This commit is contained in:
parent
f23a89ccfd
commit
11372aac8f
@ -1,6 +1,7 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
@ -166,3 +167,31 @@ def reorganize_citations(
|
||||
new_citation_info[citation.citation_num] = citation
|
||||
|
||||
return new_answer, list(new_citation_info.values())
|
||||
|
||||
|
||||
def extract_headers(
|
||||
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Extract headers specified in pass_through_headers from input headers.
|
||||
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
|
||||
|
||||
Args:
|
||||
headers: Input headers as dict or Headers object.
|
||||
|
||||
Returns:
|
||||
dict: Filtered headers based on pass_through_headers.
|
||||
"""
|
||||
if not pass_through_headers:
|
||||
return {}
|
||||
|
||||
extracted_headers: dict[str, str] = {}
|
||||
for key in pass_through_headers:
|
||||
if key in headers:
|
||||
extracted_headers[key] = headers[key]
|
||||
else:
|
||||
# fastapi makes all header keys lowercase, handling that here
|
||||
lowercase_key = key.lower()
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
@ -276,6 +276,7 @@ def stream_chat_message_objects(
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
) -> ChatPacketStream:
|
||||
@ -862,6 +863,7 @@ def stream_chat_message(
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
@ -871,6 +873,7 @@ def stream_chat_message(
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
tool_additional_headers=tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
|
@ -119,3 +119,19 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
|
||||
logger.error(
|
||||
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
|
||||
|
||||
# List of headers to pass through to tool calls (e.g., API requests made by tools)
|
||||
# This allows for dynamic configuration of tool behavior based on incoming request headers
|
||||
TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
|
||||
_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get("TOOL_PASS_THROUGH_HEADERS")
|
||||
if _TOOL_PASS_THROUGH_HEADERS_RAW:
|
||||
try:
|
||||
TOOL_PASS_THROUGH_HEADERS = json.loads(_TOOL_PASS_THROUGH_HEADERS_RAW)
|
||||
except Exception:
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
logger.error(
|
||||
"Failed to parse TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
|
@ -1,26 +1,4 @@
|
||||
from fastapi.datastructures import Headers
|
||||
|
||||
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
|
||||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
|
||||
|
||||
def get_litellm_additional_request_headers(
|
||||
headers: dict[str, str] | Headers
|
||||
) -> dict[str, str]:
|
||||
if not LITELLM_PASS_THROUGH_HEADERS:
|
||||
return {}
|
||||
|
||||
pass_through_headers: dict[str, str] = {}
|
||||
for key in LITELLM_PASS_THROUGH_HEADERS:
|
||||
if key in headers:
|
||||
pass_through_headers[key] = headers[key]
|
||||
else:
|
||||
# fastapi makes all header keys lowercase, handling that here
|
||||
lowercase_key = key.lower()
|
||||
if lowercase_key in headers:
|
||||
pass_through_headers[lowercase_key] = headers[lowercase_key]
|
||||
|
||||
return pass_through_headers
|
||||
|
||||
|
||||
def build_llm_extra_headers(
|
||||
|
@ -18,10 +18,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import extract_headers
|
||||
from danswer.chat.process_message import stream_chat_message
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
from danswer.configs.model_configs import TOOL_PASS_THROUGH_HEADERS
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import delete_chat_session
|
||||
@ -50,7 +53,6 @@ from danswer.llm.answering.prompts.citations_prompt import (
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.headers import get_litellm_additional_request_headers
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.secondary_llm_flows.chat_session_naming import (
|
||||
get_renamed_conversation_name,
|
||||
@ -229,7 +231,9 @@ def rename_chat_session(
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms(
|
||||
additional_headers=get_litellm_additional_request_headers(request.headers)
|
||||
additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
)
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
# This may be longer than what the LLM tends to produce but is the most
|
||||
@ -330,8 +334,11 @@ def handle_new_chat_message(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
use_existing_user_message=chat_message_req.use_existing_user_message,
|
||||
litellm_additional_headers=get_litellm_additional_request_headers(
|
||||
request.headers
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
tool_additional_headers=extract_headers(
|
||||
request.headers, TOOL_PASS_THROUGH_HEADERS
|
||||
),
|
||||
is_connected=is_disconnected_func,
|
||||
):
|
||||
|
@ -47,6 +47,7 @@ class CustomTool(Tool):
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
custom_headers: list[dict[str, str]] | None = [],
|
||||
tool_additional_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
@ -54,11 +55,9 @@ class CustomTool(Tool):
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self._description = self._method_spec.summary
|
||||
self.headers = (
|
||||
{header["key"]: header["value"] for header in custom_headers}
|
||||
if custom_headers
|
||||
else {}
|
||||
)
|
||||
self.headers = {
|
||||
header["key"]: header["value"] for header in (custom_headers or [])
|
||||
} | (tool_additional_headers or {})
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -185,6 +184,7 @@ class CustomTool(Tool):
|
||||
|
||||
def build_custom_tools_from_openapi_schema_and_headers(
|
||||
openapi_schema: dict[str, Any],
|
||||
tool_additional_headers: dict[str, str] | None = None,
|
||||
custom_headers: list[dict[str, str]] | None = [],
|
||||
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||
) -> list[CustomTool]:
|
||||
@ -205,7 +205,8 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [
|
||||
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
|
||||
CustomTool(method_spec, url, custom_headers, tool_additional_headers)
|
||||
for method_spec in method_specs
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user