Add custom tool headers (#2773)

* add custom tool headers

* simplify

* k

* k

* k

* nit
This commit is contained in:
pablodanswer 2024-10-15 21:37:00 -07:00 committed by GitHub
parent f23a89ccfd
commit 11372aac8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 66 additions and 32 deletions

View File

@ -1,6 +1,7 @@
import re
from typing import cast
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
@ -166,3 +167,31 @@ def reorganize_citations(
new_citation_info[citation.citation_num] = citation
return new_answer, list(new_citation_info.values())
def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:
"""
Extract headers specified in pass_through_headers from input headers.
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
Args:
headers: Input headers as dict or Headers object.
Returns:
dict: Filtered headers based on pass_through_headers.
"""
if not pass_through_headers:
return {}
extracted_headers: dict[str, str] = {}
for key in pass_through_headers:
if key in headers:
extracted_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers

View File

@ -276,6 +276,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
@ -862,6 +863,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
@ -871,6 +873,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
tool_additional_headers=tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:

View File

@ -119,3 +119,19 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
# List of headers to pass through to tool calls (e.g., API requests made by tools)
# This allows for dynamic configuration of tool behavior based on incoming request headers
TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get("TOOL_PASS_THROUGH_HEADERS")
if _TOOL_PASS_THROUGH_HEADERS_RAW:
try:
TOOL_PASS_THROUGH_HEADERS = json.loads(_TOOL_PASS_THROUGH_HEADERS_RAW)
except Exception:
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)

View File

@ -1,26 +1,4 @@
from fastapi.datastructures import Headers
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
if not LITELLM_PASS_THROUGH_HEADERS:
return {}
pass_through_headers: dict[str, str] = {}
for key in LITELLM_PASS_THROUGH_HEADERS:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]
return pass_through_headers
def build_llm_extra_headers(

View File

@ -18,10 +18,13 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
from danswer.chat.process_message import stream_chat_message
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.configs.model_configs import TOOL_PASS_THROUGH_HEADERS
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
@ -50,7 +53,6 @@ from danswer.llm.answering.prompts.citations_prompt import (
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
@ -229,7 +231,9 @@ def rename_chat_session(
try:
llm, _ = get_default_llms(
additional_headers=get_litellm_additional_request_headers(request.headers)
additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
)
)
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the most
@ -330,8 +334,11 @@ def handle_new_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
tool_additional_headers=extract_headers(
request.headers, TOOL_PASS_THROUGH_HEADERS
),
is_connected=is_disconnected_func,
):

View File

@ -47,6 +47,7 @@ class CustomTool(Tool):
method_spec: MethodSpec,
base_url: str,
custom_headers: list[dict[str, str]] | None = [],
tool_additional_headers: dict[str, str] | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
@ -54,11 +55,9 @@ class CustomTool(Tool):
self._name = self._method_spec.name
self._description = self._method_spec.summary
self.headers = (
{header["key"]: header["value"] for header in custom_headers}
if custom_headers
else {}
)
self.headers = {
header["key"]: header["value"] for header in (custom_headers or [])
} | (tool_additional_headers or {})
@property
def name(self) -> str:
@ -185,6 +184,7 @@ class CustomTool(Tool):
def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
tool_additional_headers: dict[str, str] | None = None,
custom_headers: list[dict[str, str]] | None = [],
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
@ -205,7 +205,8 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
CustomTool(method_spec, url, custom_headers, tool_additional_headers)
for method_spec in method_specs
]