From cbfbe4e5d87c6f96f8ac8f2a2f37f0bcffa9fde7 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 9 Jul 2024 11:26:13 -0700 Subject: [PATCH] Fix image generation follow up q --- backend/danswer/chat/chat_utils.py | 2 ++ backend/danswer/db/models.py | 2 ++ backend/danswer/llm/answering/models.py | 10 ++++++++++ backend/danswer/llm/utils.py | 8 ++++++-- backend/danswer/tools/images/image_generation_tool.py | 4 +++- backend/danswer/tools/images/prompt.py | 3 ++- 6 files changed, 25 insertions(+), 4 deletions(-) diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index 7e64a118e..a6c25c2b0 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo def create_chat_chain( chat_session_id: int, db_session: Session, + prefetch_tool_calls: bool = True, ) -> tuple[ChatMessage, list[ChatMessage]]: """Build the linear chain of messages without including the root message""" mainline_messages: list[ChatMessage] = [] @@ -43,6 +44,7 @@ def create_chat_chain( user_id=None, db_session=db_session, skip_permission_check=True, + prefetch_tool_calls=prefetch_tool_calls, ) id_to_msg = {msg.id: msg for msg in all_chat_messages} diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 5fcdc303b..ab3099f40 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -826,6 +826,8 @@ class ChatMessage(Base): secondary="chat_message__search_doc", back_populates="chat_messages", ) + # NOTE: Should always be attached to the `assistant` message. + # represents the tool calls used to generate this message tool_calls: Mapped[list["ToolCall"]] = relationship( "ToolCall", back_populates="message", diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index 94ca91703..432ea7338 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType from danswer.file_store.models import InMemoryChatFile from danswer.llm.override_models import PromptOverride from danswer.llm.utils import build_content_with_imgs +from danswer.tools.models import ToolCallFinalResult if TYPE_CHECKING: from danswer.db.models import ChatMessage @@ -32,6 +33,7 @@ class PreviousMessage(BaseModel): token_count: int message_type: MessageType files: list[InMemoryChatFile] + tool_calls: list[ToolCallFinalResult] @classmethod def from_chat_message( @@ -49,6 +51,14 @@ class PreviousMessage(BaseModel): for file in available_files if str(file.file_id) in message_file_ids ], + tool_calls=[ + ToolCallFinalResult( + tool_name=tool_call.tool_name, + tool_args=tool_call.tool_arguments, + tool_result=tool_call.tool_result, + ) + for tool_call in chat_message.tool_calls + ], ) def to_langchain_msg(self) -> BaseMessage: diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 7b8d2ceae..4be8e1464 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -37,9 +37,13 @@ logger = setup_logger() def translate_danswer_msg_to_langchain( msg: Union[ChatMessage, "PreviousMessage"], ) -> BaseMessage: + files: list[InMemoryChatFile] = [] + # If the message is a `ChatMessage`, it doesn't have the downloaded files - # attached. Just ignore them for now - files = [] if isinstance(msg, ChatMessage) else msg.files + # attached. Just ignore them for now. Also, OpenAI doesn't allow files to + # be attached to AI messages, so we must remove them + if isinstance(msg, PreviousMessage) and msg.message_type != MessageType.ASSISTANT: + files = msg.files content = build_content_with_imgs(msg.message, files) if msg.message_type == MessageType.SYSTEM: diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index ed145a55c..3b798c43d 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -156,7 +156,9 @@ class ImageGenerationTool(Tool): for image_generation in image_generations ] ), - img_urls=[image_generation.url for image_generation in image_generations], + # NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow + # Tool messages to contain images + # img_urls=[image_generation.url for image_generation in image_generations], ) def _generate_image(self, prompt: str) -> ImageGenerationResponse: diff --git a/backend/danswer/tools/images/prompt.py b/backend/danswer/tools/images/prompt.py index dee28b49c..7a501bf42 100644 --- a/backend/danswer/tools/images/prompt.py +++ b/backend/danswer/tools/images/prompt.py @@ -10,7 +10,8 @@ Can you please summarize them in a sentence or two? """ TOOL_CALLING_PROMPT = """ -Can you please summarize the two images you generate in a sentence or two? +Can you please summarize the two images you just generated in a sentence or two? Do not use a + numbered list. """