diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index d1da783b6f..1828d52508 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -9,6 +9,7 @@ from danswer.search.enums import QueryFlow from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse +from danswer.tools.custom.base_tool_types import ToolResultType class LlmDoc(BaseModel): @@ -130,7 +131,7 @@ class ImageGenerationDisplay(BaseModel): class CustomToolResponse(BaseModel): - response: dict + response: ToolResultType tool_name: str diff --git a/backend/danswer/tools/custom/base_tool_types.py b/backend/danswer/tools/custom/base_tool_types.py new file mode 100644 index 0000000000..7bef9a572c --- /dev/null +++ b/backend/danswer/tools/custom/base_tool_types.py @@ -0,0 +1,2 @@ +# should really be `JSON_ro`, but this causes issues with pydantic +ToolResultType = dict | list | str | int | float | bool diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py index 84f10c3ec0..f7cbf236f2 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/custom/custom_tool.py @@ -11,6 +11,7 @@ from pydantic import BaseModel from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.tools.custom.base_tool_types import ToolResultType from danswer.tools.custom.custom_tool_prompts import ( SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, ) @@ -34,7 +35,7 @@ CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response" class CustomToolCallSummary(BaseModel): tool_name: str - tool_result: dict + tool_result: ToolResultType class CustomTool(Tool):