mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00
Fix image generation follow up q
This commit is contained in:
parent
3aa0e0124b
commit
cbfbe4e5d8
@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
@ -43,6 +44,7 @@ def create_chat_chain(
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
prefetch_tool_calls=prefetch_tool_calls,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
|
@ -826,6 +826,8 @@ class ChatMessage(Base):
|
||||
secondary="chat_message__search_doc",
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
|
@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import ChatMessage
|
||||
@ -32,6 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@ -49,6 +51,14 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
@ -37,9 +37,13 @@ logger = setup_logger()
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: Union[ChatMessage, "PreviousMessage"],
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now
|
||||
files = [] if isinstance(msg, ChatMessage) else msg.files
|
||||
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
|
||||
# be attached to AI messages, so we must remove them
|
||||
if isinstance(msg, PreviousMessage) and msg.message_type != MessageType.ASSISTANT:
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
|
@ -156,7 +156,9 @@ class ImageGenerationTool(Tool):
|
||||
for image_generation in image_generations
|
||||
]
|
||||
),
|
||||
img_urls=[image_generation.url for image_generation in image_generations],
|
||||
# NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
|
||||
# Tool messages to contain images
|
||||
# img_urls=[image_generation.url for image_generation in image_generations],
|
||||
)
|
||||
|
||||
def _generate_image(self, prompt: str) -> ImageGenerationResponse:
|
||||
|
@ -10,7 +10,8 @@ Can you please summarize them in a sentence or two?
|
||||
"""
|
||||
|
||||
TOOL_CALLING_PROMPT = """
|
||||
Can you please summarize the two images you generate in a sentence or two?
|
||||
Can you please summarize the two images you just generated in a sentence or two? Do not use a
|
||||
numbered list.
|
||||
"""
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user