mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-29 17:19:36 +02:00
add multiple formats to tools (#3041)
This commit is contained in:
parent
c2d04f591d
commit
c6e8bf2d28
@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
|
||||
error_msg: str | None = None
|
||||
|
||||
|
||||
class ImageGenerationDisplay(BaseModel):
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
@ -170,7 +170,7 @@ AnswerQuestionPossibleReturn = (
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| ImageGenerationDisplay
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
|
@ -11,8 +11,8 @@ from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
@ -275,7 +275,7 @@ ChatPacket = (
|
||||
| DanswerAnswerPiece
|
||||
| AllCitations
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
@ -769,7 +769,6 @@ def stream_chat_message_objects(
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
@ -787,7 +786,7 @@ def stream_chat_message_objects(
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
@ -801,10 +800,30 @@ def stream_chat_message_objects(
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
ai_message_files = [
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV,
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
else:
|
||||
|
@ -109,11 +109,10 @@ def translate_danswer_msg_to_langchain(
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
|
||||
# be attached to AI messages, so we must remove them
|
||||
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
@ -188,10 +187,19 @@ def build_content_with_imgs(
|
||||
message: str,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
img_urls: list[str] | None = None,
|
||||
message_type: MessageType = MessageType.USER,
|
||||
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
|
||||
files = files or []
|
||||
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]
|
||||
|
||||
# Only include image files for user messages
|
||||
img_files = (
|
||||
[file for file in files if file.file_type == ChatFileType.IMAGE]
|
||||
if message_type == MessageType.USER
|
||||
else []
|
||||
)
|
||||
|
||||
img_urls = img_urls or []
|
||||
|
||||
message_main_content = _build_content(message, files)
|
||||
|
||||
if not img_files and not img_urls:
|
||||
|
@ -253,7 +253,7 @@ def stream_answer_objects(
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||
# won't be any FileChatDisplay responses since that tool is never passed in
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
|
@ -1,22 +1,34 @@
|
||||
import csv
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.key_value_store.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.base_tool import BaseTool
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
|
||||
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
|
||||
)
|
||||
@ -39,6 +51,9 @@ from danswer.tools.tool_implementations.custom.openapi_parsing import REQUEST_BO
|
||||
from danswer.tools.tool_implementations.custom.openapi_parsing import (
|
||||
validate_openapi_schema,
|
||||
)
|
||||
from danswer.tools.tool_implementations.custom.prompt import (
|
||||
build_custom_image_generation_user_prompt,
|
||||
)
|
||||
from danswer.utils.headers import header_list_to_header_dict
|
||||
from danswer.utils.headers import HeaderItemDict
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -48,9 +63,14 @@ logger = setup_logger()
|
||||
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
|
||||
|
||||
|
||||
class CustomToolFileResponse(BaseModel):
|
||||
file_ids: List[str] # References to saved images or CSVs
|
||||
|
||||
|
||||
class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
tool_result: ToolResultType
|
||||
response_type: str # e.g., 'json', 'image', 'csv', 'graph'
|
||||
tool_result: Any # The response data
|
||||
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
@ -91,6 +111,12 @@ class CustomTool(BaseTool):
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
response = cast(CustomToolCallSummary, args[0].response)
|
||||
|
||||
if response.response_type == "image" or response.response_type == "csv":
|
||||
image_response = cast(CustomToolFileResponse, response.tool_result)
|
||||
return json.dumps({"file_ids": image_response.file_ids})
|
||||
|
||||
# For JSON or other responses, return as-is
|
||||
return json.dumps(response.tool_result)
|
||||
|
||||
"""For LLMs which do NOT support explicit tool calling"""
|
||||
@ -158,6 +184,38 @@ class CustomTool(BaseTool):
|
||||
)
|
||||
return None
|
||||
|
||||
def _save_and_get_file_references(
|
||||
self, file_content: bytes | str, content_type: str
|
||||
) -> List[str]:
|
||||
with get_session_with_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
|
||||
# Handle both binary and text content
|
||||
if isinstance(file_content, str):
|
||||
content = BytesIO(file_content.encode())
|
||||
else:
|
||||
content = BytesIO(file_content)
|
||||
|
||||
file_store.save_file(
|
||||
file_name=file_id,
|
||||
content=content,
|
||||
display_name=file_id,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=content_type,
|
||||
file_metadata={
|
||||
"content_type": content_type,
|
||||
},
|
||||
)
|
||||
|
||||
return [file_id]
|
||||
|
||||
def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]:
|
||||
csv_file = StringIO(csv_text)
|
||||
reader = csv.DictReader(csv_file)
|
||||
return [row for row in reader]
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
@ -177,20 +235,103 @@ class CustomTool(BaseTool):
|
||||
|
||||
url = self._method_spec.build_url(self._base_url, path_params, query_params)
|
||||
method = self._method_spec.method
|
||||
# Log request details
|
||||
|
||||
response = requests.request(
|
||||
method, url, json=request_body, headers=self.headers
|
||||
)
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
|
||||
if "text/csv" in content_type:
|
||||
file_ids = self._save_and_get_file_references(
|
||||
response.content, content_type
|
||||
)
|
||||
tool_result = CustomToolFileResponse(file_ids=file_ids)
|
||||
response_type = "csv"
|
||||
|
||||
elif "image/" in content_type:
|
||||
file_ids = self._save_and_get_file_references(
|
||||
response.content, content_type
|
||||
)
|
||||
tool_result = CustomToolFileResponse(file_ids=file_ids)
|
||||
response_type = "image"
|
||||
|
||||
else:
|
||||
tool_result = response.json()
|
||||
response_type = "json"
|
||||
|
||||
logger.info(
|
||||
f"Returning tool response for {self._name} with type {response_type}"
|
||||
)
|
||||
|
||||
yield ToolResponse(
|
||||
id=CUSTOM_TOOL_RESPONSE_ID,
|
||||
response=CustomToolCallSummary(
|
||||
tool_name=self._name, tool_result=response.json()
|
||||
tool_name=self._name,
|
||||
response_type=response_type,
|
||||
tool_result=tool_result,
|
||||
),
|
||||
)
|
||||
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
tool_call_summary: ToolCallSummary,
|
||||
tool_responses: list[ToolResponse],
|
||||
using_tool_calling_llm: bool,
|
||||
) -> AnswerPromptBuilder:
|
||||
response = cast(CustomToolCallSummary, tool_responses[0].response)
|
||||
|
||||
# Handle non-file responses using parent class behavior
|
||||
if response.response_type not in ["image", "csv"]:
|
||||
return super().build_next_prompt(
|
||||
prompt_builder,
|
||||
tool_call_summary,
|
||||
tool_responses,
|
||||
using_tool_calling_llm,
|
||||
)
|
||||
|
||||
# Handle image and CSV file responses
|
||||
file_type = (
|
||||
ChatFileType.IMAGE
|
||||
if response.response_type == "image"
|
||||
else ChatFileType.CSV
|
||||
)
|
||||
|
||||
# Load files from storage
|
||||
files = []
|
||||
with get_session_with_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
for file_id in response.tool_result.file_ids:
|
||||
try:
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
files.append(
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
filename=file_id,
|
||||
content=file_io.read(),
|
||||
file_type=file_type,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file {file_id}")
|
||||
|
||||
# Update prompt with file content
|
||||
prompt_builder.update_user_prompt(
|
||||
build_custom_image_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
files=files,
|
||||
file_type=file_type,
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
return cast(CustomToolCallSummary, args[0].response).tool_result
|
||||
response = cast(CustomToolCallSummary, args[0].response)
|
||||
if isinstance(response.tool_result, CustomToolFileResponse):
|
||||
return response.tool_result.model_dump()
|
||||
return response.tool_result
|
||||
|
||||
|
||||
def build_custom_tools_from_openapi_schema_and_headers(
|
||||
|
25
backend/danswer/tools/tool_implementations/custom/prompt.py
Normal file
25
backend/danswer/tools/tool_implementations/custom/prompt.py
Normal file
@ -0,0 +1,25 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
|
||||
|
||||
CUSTOM_IMG_GENERATION_SUMMARY_PROMPT = """
|
||||
You have just created the attached {file_type} file in response to the following query: "{query}".
|
||||
|
||||
Can you please summarize it in a sentence or two? Do NOT include image urls or bulleted lists.
|
||||
"""
|
||||
|
||||
|
||||
def build_custom_image_generation_user_prompt(
|
||||
query: str, file_type: ChatFileType, files: list[InMemoryChatFile] | None = None
|
||||
) -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=CUSTOM_IMG_GENERATION_SUMMARY_PROMPT.format(
|
||||
query=query, file_type=file_type.value
|
||||
).strip(),
|
||||
files=files,
|
||||
)
|
||||
)
|
@ -215,6 +215,7 @@ class TestCustomTool(unittest.TestCase):
|
||||
mock_response = ToolResponse(
|
||||
id=CUSTOM_TOOL_RESPONSE_ID,
|
||||
response=CustomToolCallSummary(
|
||||
response_type="json",
|
||||
tool_name="getAssistant",
|
||||
tool_result={"id": "789", "name": "Final Assistant"},
|
||||
),
|
||||
|
@ -10,7 +10,7 @@ import {
|
||||
ChatSessionSharedStatus,
|
||||
DocumentsResponse,
|
||||
FileDescriptor,
|
||||
ImageGenerationDisplay,
|
||||
FileChatDisplay,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
RetrievalType,
|
||||
@ -1281,7 +1281,7 @@ export function ChatPage({
|
||||
query = toolCall.tool_args["query"];
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "file_ids")) {
|
||||
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
|
||||
aiMessageImages = (packet as FileChatDisplay).file_ids.map(
|
||||
(fileId) => {
|
||||
return {
|
||||
id: fileId,
|
||||
|
@ -136,7 +136,7 @@ export interface DocumentsResponse {
|
||||
rephrased_query: string | null;
|
||||
}
|
||||
|
||||
export interface ImageGenerationDisplay {
|
||||
export interface FileChatDisplay {
|
||||
file_ids: string[];
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ import {
|
||||
ChatSession,
|
||||
DocumentsResponse,
|
||||
FileDescriptor,
|
||||
ImageGenerationDisplay,
|
||||
FileChatDisplay,
|
||||
Message,
|
||||
MessageResponseIDInfo,
|
||||
RetrievalType,
|
||||
@ -103,7 +103,7 @@ export type PacketType =
|
||||
| BackendMessage
|
||||
| AnswerPiecePacket
|
||||
| DocumentsResponse
|
||||
| ImageGenerationDisplay
|
||||
| FileChatDisplay
|
||||
| StreamingError
|
||||
| MessageResponseIDInfo
|
||||
| StreamStopInfo;
|
||||
|
Loading…
x
Reference in New Issue
Block a user