From 659e8cb69edeb371d96efd73de316621ff220b8d Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sun, 15 Sep 2024 15:59:44 -0700 Subject: [PATCH] validated + build-ready --- backend/danswer/chat/process_message.py | 183 ++++++++++-------- .../danswer/llm/answering/prompts/build.py | 5 +- backend/danswer/llm/utils.py | 8 +- backend/danswer/tools/images/prompt.py | 2 +- web/src/app/chat/ChatPage.tsx | 17 +- web/src/app/chat/message/Messages.tsx | 12 +- web/src/app/chat/message/SearchSummary.tsx | 22 +-- .../shared/[chatId]/SharedChatDisplay.tsx | 1 - .../app/chat/tools/ImagePromptCitaiton.tsx | 2 +- 9 files changed, 140 insertions(+), 112 deletions(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 3d45778c8..0c88bae9b 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -676,85 +676,10 @@ def stream_chat_message_objects( ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images dropped_indices = None tool_result = None + yielded_message_id_info = True for packet in answer.processed_streamed_output: - if isinstance(packet, ToolResponse): - if packet.id == SEARCH_RESPONSE_SUMMARY_ID: - ( - qa_docs_response, - reference_db_search_docs, - dropped_indices, - ) = _handle_search_tool_response_summary( - packet=packet, - db_session=db_session, - selected_search_docs=selected_db_search_docs, - # Deduping happens at the last step to avoid harming quality by dropping content early on - dedupe_docs=retrieval_options.dedupe_docs - if retrieval_options - else False, - ) - yield qa_docs_response - elif packet.id == SECTION_RELEVANCE_LIST_ID: - relevance_sections = packet.response - - if reference_db_search_docs is not None: - llm_indices = relevant_sections_to_indices( - relevance_sections=relevance_sections, - items=[ - translate_db_search_doc_to_server_search_doc(doc) - for doc in reference_db_search_docs - ], - ) - - if dropped_indices: - llm_indices = drop_llm_indices( - llm_indices=llm_indices, - search_docs=reference_db_search_docs, - dropped_indices=dropped_indices, - ) - - yield LLMRelevanceFilterResponse( - llm_selected_doc_indices=llm_indices - ) - - elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: - yield FinalUsedContextDocsResponse( - final_context_docs=packet.response - ) - elif packet.id == IMAGE_GENERATION_RESPONSE_ID: - img_generation_response = cast( - list[ImageGenerationResponse], packet.response - ) - - file_ids = save_files_from_urls( - [img.url for img in img_generation_response] - ) - ai_message_files = [ - FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) - for file_id in file_ids - ] - yield ImageGenerationDisplay( - file_ids=[str(file_id) for file_id in file_ids] - ) - elif packet.id == INTERNET_SEARCH_RESPONSE_ID: - ( - qa_docs_response, - reference_db_search_docs, - ) = _handle_internet_search_tool_response_summary( - packet=packet, - db_session=db_session, - ) - yield qa_docs_response - elif packet.id == CUSTOM_TOOL_RESPONSE_ID: - custom_tool_response = cast(CustomToolCallSummary, packet.response) - yield CustomToolResponse( - response=custom_tool_response.tool_result, - tool_name=custom_tool_response.tool_name, - ) - elif isinstance(packet, StreamStopInfo): - print("PACKET IS ENINDG") - print(packet) - + if isinstance(packet, StreamStopInfo): if packet.stop_reason is not StreamStopReason.NEW_RESPONSE: break @@ -786,7 +711,9 @@ def stream_chat_message_objects( reference_docs=reference_db_search_docs, files=ai_message_files, token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), - citations=db_citations.citation_map if db_citations else None, + citations=( + db_citations.citation_map if db_citations is not None else None + ), error=None, tool_call=tool_call, ) @@ -806,11 +733,7 @@ def stream_chat_message_objects( else gen_ai_response_message.id, message_type=MessageType.ASSISTANT, ) - - yield MessageResponseIDInfo( - user_message_id=gen_ai_response_message.id, - reserved_assistant_message_id=reserved_message_id, - ) + yielded_message_id_info = False partial_response = partial( create_new_chat_message, @@ -824,10 +747,94 @@ def stream_chat_message_objects( commit=False, ) reference_db_search_docs = None + else: - if isinstance(packet, ToolCallFinalResult): - tool_result = packet - yield cast(ChatPacket, packet) + if not yielded_message_id_info: + yield MessageResponseIDInfo( + user_message_id=gen_ai_response_message.id, + reserved_assistant_message_id=reserved_message_id, + ) + yielded_message_id_info = True + + if isinstance(packet, ToolResponse): + if packet.id == SEARCH_RESPONSE_SUMMARY_ID: + ( + qa_docs_response, + reference_db_search_docs, + dropped_indices, + ) = _handle_search_tool_response_summary( + packet=packet, + db_session=db_session, + selected_search_docs=selected_db_search_docs, + # Deduping happens at the last step to avoid harming quality by dropping content early on + dedupe_docs=retrieval_options.dedupe_docs + if retrieval_options + else False, + ) + yield qa_docs_response + elif packet.id == SECTION_RELEVANCE_LIST_ID: + relevance_sections = packet.response + + if reference_db_search_docs is not None: + llm_indices = relevant_sections_to_indices( + relevance_sections=relevance_sections, + items=[ + translate_db_search_doc_to_server_search_doc(doc) + for doc in reference_db_search_docs + ], + ) + + if dropped_indices: + llm_indices = drop_llm_indices( + llm_indices=llm_indices, + search_docs=reference_db_search_docs, + dropped_indices=dropped_indices, + ) + + yield LLMRelevanceFilterResponse( + llm_selected_doc_indices=llm_indices + ) + + elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID: + yield FinalUsedContextDocsResponse( + final_context_docs=packet.response + ) + elif packet.id == IMAGE_GENERATION_RESPONSE_ID: + img_generation_response = cast( + list[ImageGenerationResponse], packet.response + ) + + file_ids = save_files_from_urls( + [img.url for img in img_generation_response] + ) + ai_message_files = [ + FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) + for file_id in file_ids + ] + yield ImageGenerationDisplay( + file_ids=[str(file_id) for file_id in file_ids] + ) + elif packet.id == INTERNET_SEARCH_RESPONSE_ID: + ( + qa_docs_response, + reference_db_search_docs, + ) = _handle_internet_search_tool_response_summary( + packet=packet, + db_session=db_session, + ) + yield qa_docs_response + elif packet.id == CUSTOM_TOOL_RESPONSE_ID: + custom_tool_response = cast( + CustomToolCallSummary, packet.response + ) + yield CustomToolResponse( + response=custom_tool_response.tool_result, + tool_name=custom_tool_response.tool_name, + ) + else: + if isinstance(packet, ToolCallFinalResult): + tool_result = packet + yield cast(ChatPacket, packet) logger.debug("Reached end of stream") except Exception as e: @@ -855,6 +862,10 @@ def stream_chat_message_objects( ) yield AllCitations(citations=answer.citations) + if answer.llm_answer == "": + return + + # print(answer.llm_answer) gen_ai_response_message = partial_response( reserved_message_id=reserved_message_id, message=answer.llm_answer, diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/llm/answering/prompts/build.py index ecf8eed56..a1e945f18 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/llm/answering/prompts/build.py @@ -49,7 +49,10 @@ def default_build_user_message( else user_query ) if previous_tool_calls > 0: - user_prompt = f"You have already generated the above but remember the query is: `{user_prompt}`" + user_prompt = ( + f"You have already generated the above so do not call a tool if not necessary. " + f"Remember the query is: `{user_prompt}`" + ) user_prompt = user_prompt.strip() user_msg = HumanMessage( diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index a981219e4..7383e5d72 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -139,7 +139,9 @@ def translate_danswer_msg_to_langchain( wrapped_content = "" if msg.message_type == MessageType.ASSISTANT: try: - parsed_content = json.loads(content) + parsed_content = ( + json.loads(content) if isinstance(content, str) else content + ) if ( "name" in parsed_content and parsed_content["name"] == "run_image_generation" @@ -157,9 +159,9 @@ def translate_danswer_msg_to_langchain( wrapped_content += f" Image URL: {img['url']}\n\n" wrapped_content += "[/AI IMAGE GENERATION RESPONSE]" else: - wrapped_content = content + wrapped_content = str(content) except json.JSONDecodeError: - wrapped_content = content + wrapped_content = str(content) return AIMessage(content=wrapped_content) if msg.message_type == MessageType.USER: diff --git a/backend/danswer/tools/images/prompt.py b/backend/danswer/tools/images/prompt.py index bb729bfcd..6f72554ab 100644 --- a/backend/danswer/tools/images/prompt.py +++ b/backend/danswer/tools/images/prompt.py @@ -4,7 +4,7 @@ from danswer.llm.utils import build_content_with_imgs IMG_GENERATION_SUMMARY_PROMPT = """ -You have just created the attached images in response to the following query: "{query}". +You have just created the most recent attached images in response to the following query: "{query}". Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists. """ diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 9ac4b99d1..a607ea60f 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -1272,6 +1272,8 @@ export function ChatPage({ } if (Object.hasOwn(packet, "user_message_id")) { + debugger; + let newParentMessageId = dynamicParentMessage.messageId; const messageResponseIDInfo = packet as MessageResponseIDInfo; @@ -1325,6 +1327,8 @@ export function ChatPage({ dynamicAssistantMessage.retrievalType = RetrievalType.Search; retrievalType = RetrievalType.Search; } else if (Object.hasOwn(packet, "tool_name")) { + debugger; + dynamicAssistantMessage.toolCall = { tool_name: (packet as ToolCallMetadata).tool_name, tool_args: (packet as ToolCallMetadata).tool_args, @@ -1405,6 +1409,15 @@ export function ChatPage({ }); }; + console.log("\n-----"); + console.log( + "dynamicParentMessage", + JSON.stringify(dynamicParentMessage) + ); + console.log( + "dynamicAssistantMessage", + JSON.stringify(dynamicAssistantMessage) + ); let { messageMap } = updateFn([ dynamicParentMessage, dynamicAssistantMessage, @@ -2225,7 +2238,6 @@ export function ChatPage({ query={ messageHistory[i]?.query || undefined } - personaName={liveAssistant.name} citedDocuments={getCitedDocumentsFromMessage( message )} @@ -2337,7 +2349,6 @@ export function ChatPage({ {message.message} @@ -2385,7 +2396,6 @@ export function ChatPage({ alternativeAssistant } messageId={null} - personaName={liveAssistant.name} content={
{loadingError} diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 03e418fc8..8ee7c36a6 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -143,7 +143,6 @@ export const AIMessage = ({ files, selectedDocuments, query, - personaName, citedDocuments, toolCall, isComplete, @@ -175,7 +174,6 @@ export const AIMessage = ({ content: string | JSX.Element; files?: FileDescriptor[]; query?: string; - personaName?: string; citedDocuments?: [string, DanswerDocument][] | null; toolCall?: ToolCallMetadata | null; isComplete?: boolean; @@ -191,6 +189,7 @@ export const AIMessage = ({ setPopup?: (popupSpec: PopupSpec | null) => void; }) => { const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const toolCallGenerating = toolCall && !toolCall.tool_result; const processContent = (content: string | JSX.Element) => { if (typeof content !== "string") { @@ -214,8 +213,9 @@ export const AIMessage = ({ } } if ( + isComplete && toolCall?.tool_result && - toolCall.tool_result.tool_name == INTERNET_SEARCH_TOOL_NAME + toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME ) { return content + ` [${toolCall.tool_name}]()`; } @@ -225,6 +225,7 @@ export const AIMessage = ({ const finalContent = processContent(content as string); const [isRegenerateHovered, setIsRegenerateHovered] = useState(false); + const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking(); const settings = useContext(SettingsContext); @@ -413,7 +414,10 @@ export const AIMessage = ({ return ( null} // only allow closing from the icon + onOpenChange={ + () => null + // setIsPopoverOpen(isPopoverOpen => !isPopoverOpen) + } // only allow closing from the icon content={
+ + + {handleSearchQueryEdit ? ( )} diff --git a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx index 489163aa3..205d99b77 100644 --- a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx +++ b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx @@ -101,7 +101,6 @@ export function SharedChatDisplay({ messageId={message.messageId} content={message.message} files={message.files || []} - personaName={chatSession.persona_name} citedDocuments={getCitedDocumentsFromMessage(message)} isComplete /> diff --git a/web/src/app/chat/tools/ImagePromptCitaiton.tsx b/web/src/app/chat/tools/ImagePromptCitaiton.tsx index 5f7cb42c2..9c7ce9535 100644 --- a/web/src/app/chat/tools/ImagePromptCitaiton.tsx +++ b/web/src/app/chat/tools/ImagePromptCitaiton.tsx @@ -48,7 +48,7 @@ const DualPromptDisplay = forwardRef( onMouseDown={() => copyToClipboard(prompt, index)} className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200" > - {copied != null ? ( + {copied == index ? ( <> Copied!