From 7d201f67d4e07f65916d2b474b0a3c7abdd39a54 Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 23 Aug 2024 11:13:36 -0700 Subject: [PATCH] Fix typing for custom tool response --- backend/danswer/chat/models.py | 3 ++- backend/danswer/tools/custom/base_tool_types.py | 2 ++ backend/danswer/tools/custom/custom_tool.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 backend/danswer/tools/custom/base_tool_types.py diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index d1da783b6..1828d5250 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -9,6 +9,7 @@ from danswer.search.enums import QueryFlow from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse +from danswer.tools.custom.base_tool_types import ToolResultType class LlmDoc(BaseModel): @@ -130,7 +131,7 @@ class ImageGenerationDisplay(BaseModel): class CustomToolResponse(BaseModel): - response: dict + response: ToolResultType tool_name: str diff --git a/backend/danswer/tools/custom/base_tool_types.py b/backend/danswer/tools/custom/base_tool_types.py new file mode 100644 index 000000000..7bef9a572 --- /dev/null +++ b/backend/danswer/tools/custom/base_tool_types.py @@ -0,0 +1,2 @@ +# should really be `JSON_ro`, but this causes issues with pydantic +ToolResultType = dict | list | str | int | float | bool diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py index 84f10c3ec..f7cbf236f 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/custom/custom_tool.py @@ -11,6 +11,7 @@ from pydantic import BaseModel from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.tools.custom.base_tool_types import ToolResultType from danswer.tools.custom.custom_tool_prompts import ( SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, ) @@ -34,7 +35,7 @@ CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response" class CustomToolCallSummary(BaseModel): tool_name: str - tool_result: dict + tool_result: ToolResultType class CustomTool(Tool):