validated + build-ready

This commit is contained in:
pablodanswer 2024-09-15 15:59:44 -07:00
parent 681175e9c3
commit 659e8cb69e
9 changed files with 140 additions and 112 deletions

View File

@ -676,85 +676,10 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
yielded_message_id_info = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
print("PACKET IS ENINDG")
print(packet)
if isinstance(packet, StreamStopInfo):
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
break
@ -786,7 +711,9 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations.citation_map if db_citations else None,
citations=(
db_citations.citation_map if db_citations is not None else None
),
error=None,
tool_call=tool_call,
)
@ -806,11 +733,7 @@ def stream_chat_message_objects(
else gen_ai_response_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=gen_ai_response_message.id,
reserved_assistant_message_id=reserved_message_id,
)
yielded_message_id_info = False
partial_response = partial(
create_new_chat_message,
@ -824,10 +747,94 @@ def stream_chat_message_objects(
commit=False,
)
reference_db_search_docs = None
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
if not yielded_message_id_info:
yield MessageResponseIDInfo(
user_message_id=gen_ai_response_message.id,
reserved_assistant_message_id=reserved_message_id,
)
yielded_message_id_info = True
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(
CustomToolCallSummary, packet.response
)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
@ -855,6 +862,10 @@ def stream_chat_message_objects(
)
yield AllCitations(citations=answer.citations)
if answer.llm_answer == "":
return
# print(answer.llm_answer)
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,

View File

@ -49,7 +49,10 @@ def default_build_user_message(
else user_query
)
if previous_tool_calls > 0:
user_prompt = f"You have already generated the above but remember the query is: `{user_prompt}`"
user_prompt = (
f"You have already generated the above so do not call a tool if not necessary. "
f"Remember the query is: `{user_prompt}`"
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(

View File

@ -139,7 +139,9 @@ def translate_danswer_msg_to_langchain(
wrapped_content = ""
if msg.message_type == MessageType.ASSISTANT:
try:
parsed_content = json.loads(content)
parsed_content = (
json.loads(content) if isinstance(content, str) else content
)
if (
"name" in parsed_content
and parsed_content["name"] == "run_image_generation"
@ -157,9 +159,9 @@ def translate_danswer_msg_to_langchain(
wrapped_content += f" Image URL: {img['url']}\n\n"
wrapped_content += "[/AI IMAGE GENERATION RESPONSE]"
else:
wrapped_content = content
wrapped_content = str(content)
except json.JSONDecodeError:
wrapped_content = content
wrapped_content = str(content)
return AIMessage(content=wrapped_content)
if msg.message_type == MessageType.USER:

View File

@ -4,7 +4,7 @@ from danswer.llm.utils import build_content_with_imgs
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".
You have just created the most recent attached images in response to the following query: "{query}".
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""

View File

@ -1272,6 +1272,8 @@ export function ChatPage({
}
if (Object.hasOwn(packet, "user_message_id")) {
debugger;
let newParentMessageId = dynamicParentMessage.messageId;
const messageResponseIDInfo = packet as MessageResponseIDInfo;
@ -1325,6 +1327,8 @@ export function ChatPage({
dynamicAssistantMessage.retrievalType = RetrievalType.Search;
retrievalType = RetrievalType.Search;
} else if (Object.hasOwn(packet, "tool_name")) {
debugger;
dynamicAssistantMessage.toolCall = {
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
@ -1405,6 +1409,15 @@ export function ChatPage({
});
};
console.log("\n-----");
console.log(
"dynamicParentMessage",
JSON.stringify(dynamicParentMessage)
);
console.log(
"dynamicAssistantMessage",
JSON.stringify(dynamicAssistantMessage)
);
let { messageMap } = updateFn([
dynamicParentMessage,
dynamicAssistantMessage,
@ -2225,7 +2238,6 @@ export function ChatPage({
query={
messageHistory[i]?.query || undefined
}
personaName={liveAssistant.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
@ -2337,7 +2349,6 @@ export function ChatPage({
<AIMessage
currentPersona={liveAssistant}
messageId={message.messageId}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@ -2385,7 +2396,6 @@ export function ChatPage({
alternativeAssistant
}
messageId={null}
personaName={liveAssistant.name}
content={
<div
key={"Generating"}
@ -2405,7 +2415,6 @@ export function ChatPage({
<AIMessage
currentPersona={liveAssistant}
messageId={-1}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{loadingError}

View File

@ -143,7 +143,6 @@ export const AIMessage = ({
files,
selectedDocuments,
query,
personaName,
citedDocuments,
toolCall,
isComplete,
@ -175,7 +174,6 @@ export const AIMessage = ({
content: string | JSX.Element;
files?: FileDescriptor[];
query?: string;
personaName?: string;
citedDocuments?: [string, DanswerDocument][] | null;
toolCall?: ToolCallMetadata | null;
isComplete?: boolean;
@ -191,6 +189,7 @@ export const AIMessage = ({
setPopup?: (popupSpec: PopupSpec | null) => void;
}) => {
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const toolCallGenerating = toolCall && !toolCall.tool_result;
const processContent = (content: string | JSX.Element) => {
if (typeof content !== "string") {
@ -214,8 +213,9 @@ export const AIMessage = ({
}
}
if (
isComplete &&
toolCall?.tool_result &&
toolCall.tool_result.tool_name == INTERNET_SEARCH_TOOL_NAME
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
) {
return content + ` [${toolCall.tool_name}]()`;
}
@ -225,6 +225,7 @@ export const AIMessage = ({
const finalContent = processContent(content as string);
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
const settings = useContext(SettingsContext);
@ -413,7 +414,10 @@ export const AIMessage = ({
return (
<Popover
open={isPopoverOpen}
onOpenChange={() => null} // only allow closing from the icon
onOpenChange={
() => null
// setIsPopoverOpen(isPopoverOpen => !isPopoverOpen)
} // only allow closing from the icon
content={
<button
onMouseDown={() => {

View File

@ -203,7 +203,6 @@ export function SearchSummary({
target="_blank"
className="line-clamp-1 text-text-900"
>
{/* <Citation link={doc.link} index={ind + 1} /> */}
<p className="shrink truncate ellipsis break-all ">
{doc.semantic_identifier || doc.document_id}
</p>
@ -239,6 +238,16 @@ export function SearchSummary({
searchingForDisplay
)}
</div>
<button
className="my-auto invisible group-hover:visible transition-all duration-300 rounded"
onClick={toggleDropdown}
>
<ChevronDownIcon
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
{handleSearchQueryEdit ? (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
@ -251,17 +260,8 @@ export function SearchSummary({
</button>
</Tooltip>
) : (
"Hi"
<></>
)}
<button
className="my-auto invisible group-hover:visible transition-all duration-300 hover:bg-hover rounded"
onClick={toggleDropdown}
>
<ChevronDownIcon
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
</>
)}
</div>

View File

@ -101,7 +101,6 @@ export function SharedChatDisplay({
messageId={message.messageId}
content={message.message}
files={message.files || []}
personaName={chatSession.persona_name}
citedDocuments={getCitedDocumentsFromMessage(message)}
isComplete
/>

View File

@ -48,7 +48,7 @@ const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
onMouseDown={() => copyToClipboard(prompt, index)}
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
>
{copied != null ? (
{copied == index ? (
<>
<FiCheck className="mr-2" size={16} />
Copied!