Add support for passthrough auth for custom tool calls (#2824)

* Add support for passthrough auth for custom tool calls

* Fix formatting
This commit is contained in:
Chris Weaver 2024-10-16 15:50:16 -07:00 committed by GitHub
parent db0779dd02
commit 33974fc12c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 134 additions and 50 deletions

View File

@ -105,6 +105,7 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@ -276,7 +277,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
@ -640,7 +641,12 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
@ -863,7 +869,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
@ -873,7 +879,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
tool_additional_headers=tool_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:

View File

@ -119,19 +119,3 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
# List of headers to pass through to tool calls (e.g., API requests made by tools)
# This allows for dynamic configuration of tool behavior based on incoming request headers
TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get("TOOL_PASS_THROUGH_HEADERS")
if _TOOL_PASS_THROUGH_HEADERS_RAW:
try:
TOOL_PASS_THROUGH_HEADERS = json.loads(_TOOL_PASS_THROUGH_HEADERS_RAW)
except Exception:
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)

View File

@ -0,0 +1,22 @@
import json
import os
# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
"CUSTOM_TOOL_PASS_THROUGH_HEADERS"
)
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW:
try:
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads(
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW
)
except Exception:
# need to import here to avoid circular imports
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)

View File

@ -60,6 +60,7 @@ from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
@ -1288,7 +1289,7 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
custom_headers: Mapped[list[HeaderItemDict] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.

View File

@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from uuid import UUID
from sqlalchemy import select
@ -6,6 +7,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -67,7 +69,9 @@ def update_tool(
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
db_session.commit()
return tool

View File

@ -7,9 +7,9 @@ from danswer.db.llm import fetch_provider
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.utils.headers import build_llm_extra_headers
def get_main_llm_from_tuple(

View File

@ -1,12 +0,0 @@
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
def build_llm_extra_headers(
additional_headers: dict[str, str] | None = None
) -> dict[str, str]:
extra_headers: dict[str, str] = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return extra_headers

View File

@ -25,7 +25,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.configs.model_configs import TOOL_PASS_THROUGH_HEADERS
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
@ -74,6 +73,7 @@ from danswer.server.query_and_chat.models import RenameChatSessionResponse
from danswer.server.query_and_chat.models import SearchFeedbackRequest
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.headers import get_custom_tool_additional_request_headers
from danswer.utils.logger import setup_logger
@ -338,8 +338,8 @@ def handle_new_chat_message(
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
tool_additional_headers=extract_headers(
request.headers, TOOL_PASS_THROUGH_HEADERS
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
is_connected=is_disconnected_func,
):

View File

@ -29,6 +29,8 @@ from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -46,8 +48,7 @@ class CustomTool(Tool):
self,
method_spec: MethodSpec,
base_url: str,
custom_headers: list[dict[str, str]] | None = [],
tool_additional_headers: dict[str, str] | None = None,
custom_headers: list[HeaderItemDict] | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
@ -55,9 +56,9 @@ class CustomTool(Tool):
self._name = self._method_spec.name
self._description = self._method_spec.summary
self.headers = {
header["key"]: header["value"] for header in (custom_headers or [])
} | (tool_additional_headers or {})
self.headers = (
header_list_to_header_dict(custom_headers) if custom_headers else {}
)
@property
def name(self) -> str:
@ -184,8 +185,7 @@ class CustomTool(Tool):
def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
tool_additional_headers: dict[str, str] | None = None,
custom_headers: list[dict[str, str]] | None = [],
custom_headers: list[HeaderItemDict] | None = None,
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
@ -205,8 +205,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers, tool_additional_headers)
for method_spec in method_specs
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]

View File

@ -11,13 +11,13 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@ -0,0 +1,79 @@
from typing import TypedDict
from fastapi.datastructures import Headers
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.configs.tool_configs import CUSTOM_TOOL_PASS_THROUGH_HEADERS
from danswer.utils.logger import setup_logger
logger = setup_logger()
class HeaderItemDict(TypedDict):
key: str
value: str
def clean_header_list(headers_to_clean: list[HeaderItemDict]) -> dict[str, str]:
cleaned_headers: dict[str, str] = {}
for item in headers_to_clean:
key = item["key"]
value = item["value"]
if key in cleaned_headers:
logger.warning(
f"Duplicate header {key} found in custom headers, ignoring..."
)
continue
cleaned_headers[key] = value
return cleaned_headers
def header_dict_to_header_list(header_dict: dict[str, str]) -> list[HeaderItemDict]:
return [{"key": key, "value": value} for key, value in header_dict.items()]
def header_list_to_header_dict(header_list: list[HeaderItemDict]) -> dict[str, str]:
return {header["key"]: header["value"] for header in header_list}
def get_relevant_headers(
headers: dict[str, str] | Headers, desired_headers: list[str] | None
) -> dict[str, str]:
if not desired_headers:
return {}
pass_through_headers: dict[str, str] = {}
for key in desired_headers:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]
return pass_through_headers
def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
return get_relevant_headers(headers, LITELLM_PASS_THROUGH_HEADERS)
def build_llm_extra_headers(
additional_headers: dict[str, str] | None = None
) -> dict[str, str]:
extra_headers: dict[str, str] = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return extra_headers
def get_custom_tool_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
return get_relevant_headers(headers, CUSTOM_TOOL_PASS_THROUGH_HEADERS)

View File

@ -13,6 +13,7 @@ from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import HeaderItemDict
class TestCustomTool(unittest.TestCase):
@ -143,7 +144,7 @@ class TestCustomTool(unittest.TestCase):
Test the custom tool with custom headers.
Verifies that the tool correctly includes the custom headers in the request.
"""
custom_headers: list[dict[str, str]] = [
custom_headers: list[HeaderItemDict] = [
{"key": "Authorization", "value": "Bearer token123"},
{"key": "Custom-Header", "value": "CustomValue"},
]
@ -171,7 +172,7 @@ class TestCustomTool(unittest.TestCase):
Test the custom tool with an empty list of custom headers.
Verifies that the tool correctly handles an empty list of headers.
"""
custom_headers: list[dict[str, str]] = []
custom_headers: list[HeaderItemDict] = []
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema,
custom_headers=custom_headers,