From c6e8bf2d286b025cb485f84890223018d11fd55e Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sun, 3 Nov 2024 15:54:19 -0800 Subject: [PATCH] add multiple formats to tools (#3041) --- backend/danswer/chat/models.py | 4 +- backend/danswer/chat/process_message.py | 35 +++- backend/danswer/llm/utils.py | 18 ++- .../one_shot_answer/answer_question.py | 2 +- .../custom/custom_tool.py | 151 +++++++++++++++++- .../tool_implementations/custom/prompt.py | 25 +++ .../danswer/tools/custom/test_custom_tools.py | 1 + web/src/app/chat/ChatPage.tsx | 4 +- web/src/app/chat/interfaces.ts | 2 +- web/src/app/chat/lib.tsx | 4 +- 10 files changed, 220 insertions(+), 26 deletions(-) create mode 100644 backend/danswer/tools/tool_implementations/custom/prompt.py diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index d5925fc2e..159506c07 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer): error_msg: str | None = None -class ImageGenerationDisplay(BaseModel): +class FileChatDisplay(BaseModel): file_ids: list[str] @@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = ( | DanswerQuotes | CitationInfo | DanswerContexts - | ImageGenerationDisplay + | FileChatDisplay | CustomToolResponse | StreamingError | StreamStopInfo diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 252f6df0f..314e432b8 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -11,8 +11,8 @@ from danswer.chat.models import AllCitations from danswer.chat.models import CitationInfo from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import FileChatDisplay from danswer.chat.models import FinalUsedContextDocsResponse -from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import MessageResponseIDInfo from danswer.chat.models import MessageSpecificCitations @@ -275,7 +275,7 @@ ChatPacket = ( | DanswerAnswerPiece | AllCitations | CitationInfo - | ImageGenerationDisplay + | FileChatDisplay | CustomToolResponse | MessageSpecificCitations | MessageResponseIDInfo @@ -769,7 +769,6 @@ def stream_chat_message_objects( yield LLMRelevanceFilterResponse( llm_selected_doc_indices=llm_indices ) - elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: yield FinalUsedContextDocsResponse( final_context_docs=packet.response @@ -787,7 +786,7 @@ def stream_chat_message_objects( FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) for file_id in file_ids ] - yield ImageGenerationDisplay( + yield FileChatDisplay( file_ids=[str(file_id) for file_id in file_ids] ) elif packet.id == INTERNET_SEARCH_RESPONSE_ID: @@ -801,10 +800,30 @@ def stream_chat_message_objects( yield qa_docs_response elif packet.id == CUSTOM_TOOL_RESPONSE_ID: custom_tool_response = cast(CustomToolCallSummary, packet.response) - yield CustomToolResponse( - response=custom_tool_response.tool_result, - tool_name=custom_tool_response.tool_name, - ) + + if ( + custom_tool_response.response_type == "image" + or custom_tool_response.response_type == "csv" + ): + file_ids = custom_tool_response.tool_result.file_ids + ai_message_files = [ + FileDescriptor( + id=str(file_id), + type=ChatFileType.IMAGE + if custom_tool_response.response_type == "image" + else ChatFileType.CSV, + ) + for file_id in file_ids + ] + yield FileChatDisplay( + file_ids=[str(file_id) for file_id in file_ids] + ) + else: + yield CustomToolResponse( + response=custom_tool_response.tool_result, + tool_name=custom_tool_response.tool_name, + ) + elif isinstance(packet, StreamStopInfo): pass else: diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 77e137552..240171469 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -109,11 +109,10 @@ def translate_danswer_msg_to_langchain( files: list[InMemoryChatFile] = [] # If the message is a `ChatMessage`, it doesn't have the downloaded files - # attached. Just ignore them for now. Also, OpenAI doesn't allow files to - # be attached to AI messages, so we must remove them - if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT: + # attached. Just ignore them for now. + if not isinstance(msg, ChatMessage): files = msg.files - content = build_content_with_imgs(msg.message, files) + content = build_content_with_imgs(msg.message, files, message_type=msg.message_type) if msg.message_type == MessageType.SYSTEM: raise ValueError("System messages are not currently part of history") @@ -188,10 +187,19 @@ def build_content_with_imgs( message: str, files: list[InMemoryChatFile] | None = None, img_urls: list[str] | None = None, + message_type: MessageType = MessageType.USER, ) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type files = files or [] - img_files = [file for file in files if file.file_type == ChatFileType.IMAGE] + + # Only include image files for user messages + img_files = ( + [file for file in files if file.file_type == ChatFileType.IMAGE] + if message_type == MessageType.USER + else [] + ) + img_urls = img_urls or [] + message_main_content = _build_content(message, files) if not img_files and not img_urls: diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 9ece5f4bb..f3cbe2b60 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -253,7 +253,7 @@ def stream_answer_objects( return_contexts=query_req.return_contexts, skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation, ) - # won't be any ImageGenerationDisplay responses since that tool is never passed in + # won't be any FileChatDisplay responses since that tool is never passed in for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): # for one-shot flow, don't currently do anything with these if isinstance(packet, ToolResponse): diff --git a/backend/danswer/tools/tool_implementations/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py index a1fb4bb69..eace6d53a 100644 --- a/backend/danswer/tools/tool_implementations/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -1,22 +1,34 @@ +import csv import json +import uuid from collections.abc import Generator +from io import BytesIO +from io import StringIO from typing import Any from typing import cast +from typing import Dict +from typing import List import requests from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage from pydantic import BaseModel +from danswer.configs.constants import FileOrigin +from danswer.db.engine import get_session_with_tenant +from danswer.file_store.file_store import get_default_file_store +from danswer.file_store.models import ChatFileType +from danswer.file_store.models import InMemoryChatFile from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.prompts.build import AnswerPromptBuilder from danswer.llm.interfaces import LLM from danswer.tools.base_tool import BaseTool +from danswer.tools.message import ToolCallSummary from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER from danswer.tools.models import DynamicSchemaInfo from danswer.tools.models import MESSAGE_ID_PLACEHOLDER from danswer.tools.models import ToolResponse -from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType from danswer.tools.tool_implementations.custom.custom_tool_prompts import ( SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, ) @@ -39,6 +51,9 @@ from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BO from danswer.tools.tool_implementations.custom.openapi_parsing import ( validate_openapi_schema, ) +from danswer.tools.tool_implementations.custom.prompt import ( + build_custom_image_generation_user_prompt, +) from danswer.utils.headers import header_list_to_header_dict from danswer.utils.headers import HeaderItemDict from danswer.utils.logger import setup_logger @@ -48,9 +63,14 @@ logger = setup_logger() CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response" +class CustomToolFileResponse(BaseModel): + file_ids: List[str] # References to saved images or CSVs + + class CustomToolCallSummary(BaseModel): tool_name: str - tool_result: ToolResultType + response_type: str # e.g., 'json', 'image', 'csv', 'graph' + tool_result: Any # The response data class CustomTool(BaseTool): @@ -91,6 +111,12 @@ class CustomTool(BaseTool): self, *args: ToolResponse ) -> str | list[str | dict[str, Any]]: response = cast(CustomToolCallSummary, args[0].response) + + if response.response_type == "image" or response.response_type == "csv": + image_response = cast(CustomToolFileResponse, response.tool_result) + return json.dumps({"file_ids": image_response.file_ids}) + + # For JSON or other responses, return as-is return json.dumps(response.tool_result) """For LLMs which do NOT support explicit tool calling""" @@ -158,6 +184,38 @@ class CustomTool(BaseTool): ) return None + def _save_and_get_file_references( + self, file_content: bytes | str, content_type: str + ) -> List[str]: + with get_session_with_tenant() as db_session: + file_store = get_default_file_store(db_session) + + file_id = str(uuid.uuid4()) + + # Handle both binary and text content + if isinstance(file_content, str): + content = BytesIO(file_content.encode()) + else: + content = BytesIO(file_content) + + file_store.save_file( + file_name=file_id, + content=content, + display_name=file_id, + file_origin=FileOrigin.CHAT_UPLOAD, + file_type=content_type, + file_metadata={ + "content_type": content_type, + }, + ) + + return [file_id] + + def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]: + csv_file = StringIO(csv_text) + reader = csv.DictReader(csv_file) + return [row for row in reader] + """Actual execution of the tool""" def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: @@ -177,20 +235,103 @@ class CustomTool(BaseTool): url = self._method_spec.build_url(self._base_url, path_params, query_params) method = self._method_spec.method - # Log request details + response = requests.request( method, url, json=request_body, headers=self.headers ) + content_type = response.headers.get("Content-Type", "") + + if "text/csv" in content_type: + file_ids = self._save_and_get_file_references( + response.content, content_type + ) + tool_result = CustomToolFileResponse(file_ids=file_ids) + response_type = "csv" + + elif "image/" in content_type: + file_ids = self._save_and_get_file_references( + response.content, content_type + ) + tool_result = CustomToolFileResponse(file_ids=file_ids) + response_type = "image" + + else: + tool_result = response.json() + response_type = "json" + + logger.info( + f"Returning tool response for {self._name} with type {response_type}" + ) yield ToolResponse( id=CUSTOM_TOOL_RESPONSE_ID, response=CustomToolCallSummary( - tool_name=self._name, tool_result=response.json() + tool_name=self._name, + response_type=response_type, + tool_result=tool_result, ), ) + def build_next_prompt( + self, + prompt_builder: AnswerPromptBuilder, + tool_call_summary: ToolCallSummary, + tool_responses: list[ToolResponse], + using_tool_calling_llm: bool, + ) -> AnswerPromptBuilder: + response = cast(CustomToolCallSummary, tool_responses[0].response) + + # Handle non-file responses using parent class behavior + if response.response_type not in ["image", "csv"]: + return super().build_next_prompt( + prompt_builder, + tool_call_summary, + tool_responses, + using_tool_calling_llm, + ) + + # Handle image and CSV file responses + file_type = ( + ChatFileType.IMAGE + if response.response_type == "image" + else ChatFileType.CSV + ) + + # Load files from storage + files = [] + with get_session_with_tenant() as db_session: + file_store = get_default_file_store(db_session) + + for file_id in response.tool_result.file_ids: + try: + file_io = file_store.read_file(file_id, mode="b") + files.append( + InMemoryChatFile( + file_id=file_id, + filename=file_id, + content=file_io.read(), + file_type=file_type, + ) + ) + except Exception: + logger.exception(f"Failed to read file {file_id}") + + # Update prompt with file content + prompt_builder.update_user_prompt( + build_custom_image_generation_user_prompt( + query=prompt_builder.get_user_message_content(), + files=files, + file_type=file_type, + ) + ) + + return prompt_builder + def final_result(self, *args: ToolResponse) -> JSON_ro: - return cast(CustomToolCallSummary, args[0].response).tool_result + response = cast(CustomToolCallSummary, args[0].response) + if isinstance(response.tool_result, CustomToolFileResponse): + return response.tool_result.model_dump() + return response.tool_result def build_custom_tools_from_openapi_schema_and_headers( diff --git a/backend/danswer/tools/tool_implementations/custom/prompt.py b/backend/danswer/tools/tool_implementations/custom/prompt.py new file mode 100644 index 000000000..9911594a9 --- /dev/null +++ b/backend/danswer/tools/tool_implementations/custom/prompt.py @@ -0,0 +1,25 @@ +from langchain_core.messages import HumanMessage + +from danswer.file_store.models import ChatFileType +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.utils import build_content_with_imgs + + +CUSTOM_IMG_GENERATION_SUMMARY_PROMPT = """ +You have just created the attached {file_type} file in response to the following query: "{query}". + +Can you please summarize it in a sentence or two? Do NOT include image urls or bulleted lists. +""" + + +def build_custom_image_generation_user_prompt( + query: str, file_type: ChatFileType, files: list[InMemoryChatFile] | None = None +) -> HumanMessage: + return HumanMessage( + content=build_content_with_imgs( + message=CUSTOM_IMG_GENERATION_SUMMARY_PROMPT.format( + query=query, file_type=file_type.value + ).strip(), + files=files, + ) + ) diff --git a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py index f56336809..4d47a8761 100644 --- a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py +++ b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py @@ -215,6 +215,7 @@ class TestCustomTool(unittest.TestCase): mock_response = ToolResponse( id=CUSTOM_TOOL_RESPONSE_ID, response=CustomToolCallSummary( + response_type="json", tool_name="getAssistant", tool_result={"id": "789", "name": "Final Assistant"}, ), diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index dcf1baf14..43ce519a1 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -10,7 +10,7 @@ import { ChatSessionSharedStatus, DocumentsResponse, FileDescriptor, - ImageGenerationDisplay, + FileChatDisplay, Message, MessageResponseIDInfo, RetrievalType, @@ -1281,7 +1281,7 @@ export function ChatPage({ query = toolCall.tool_args["query"]; } } else if (Object.hasOwn(packet, "file_ids")) { - aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( + aiMessageImages = (packet as FileChatDisplay).file_ids.map( (fileId) => { return { id: fileId, diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 20ea4e7d2..dd7368374 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -136,7 +136,7 @@ export interface DocumentsResponse { rephrased_query: string | null; } -export interface ImageGenerationDisplay { +export interface FileChatDisplay { file_ids: string[]; } diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index cda037a08..41a83eee1 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -12,7 +12,7 @@ import { ChatSession, DocumentsResponse, FileDescriptor, - ImageGenerationDisplay, + FileChatDisplay, Message, MessageResponseIDInfo, RetrievalType, @@ -103,7 +103,7 @@ export type PacketType = | BackendMessage | AnswerPiecePacket | DocumentsResponse - | ImageGenerationDisplay + | FileChatDisplay | StreamingError | MessageResponseIDInfo | StreamStopInfo;