Fix image generation follow up q

This commit is contained in:
Weves 2024-07-09 11:26:13 -07:00 committed by Chris Weaver
parent 3aa0e0124b
commit cbfbe4e5d8
6 changed files with 25 additions and 4 deletions

View File

@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
@ -43,6 +44,7 @@ def create_chat_chain(
user_id=None,
db_session=db_session,
skip_permission_check=True,
prefetch_tool_calls=prefetch_tool_calls,
)
id_to_msg = {msg.id: msg for msg in all_chat_messages}

View File

@ -826,6 +826,8 @@ class ChatMessage(Base):
secondary="chat_message__search_doc",
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",

View File

@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
@ -32,6 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
@classmethod
def from_chat_message(
@ -49,6 +51,14 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
)
def to_langchain_msg(self) -> BaseMessage:

View File

@ -37,9 +37,13 @@ logger = setup_logger()
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now
files = [] if isinstance(msg, ChatMessage) else msg.files
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# be attached to AI messages, so we must remove them
if isinstance(msg, PreviousMessage) and msg.message_type != MessageType.ASSISTANT:
files = msg.files
content = build_content_with_imgs(msg.message, files)
if msg.message_type == MessageType.SYSTEM:

View File

@ -156,7 +156,9 @@ class ImageGenerationTool(Tool):
for image_generation in image_generations
]
),
img_urls=[image_generation.url for image_generation in image_generations],
# NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
# Tool messages to contain images
# img_urls=[image_generation.url for image_generation in image_generations],
)
def _generate_image(self, prompt: str) -> ImageGenerationResponse:

View File

@ -10,7 +10,8 @@ Can you please summarize them in a sentence or two?
"""
TOOL_CALLING_PROMPT = """
Can you please summarize the two images you generate in a sentence or two?
Can you please summarize the two images you just generated in a sentence or two? Do not use a
numbered list.
"""