mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
large number of PR comments addressed
This commit is contained in:
@@ -2,7 +2,6 @@ import traceback
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
@@ -13,6 +12,7 @@ from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
from onyx.chat.models import AgentSearchPacket
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerPostInfo
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ChatOnyxBotResponse
|
||||
from onyx.chat.models import CitationConfig
|
||||
@@ -33,6 +33,7 @@ from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
@@ -196,9 +197,9 @@ def _handle_search_tool_response_summary(
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
level, question_nr = None, None
|
||||
level, question_num = None, None
|
||||
if isinstance(packet, ExtendedToolResponse):
|
||||
level, question_nr = packet.level, packet.level_question_nr
|
||||
level, question_num = packet.level, packet.level_question_num
|
||||
return (
|
||||
QADocsResponse(
|
||||
rephrased_query=response_sumary.rephrased_query,
|
||||
@@ -209,7 +210,7 @@ def _handle_search_tool_response_summary(
|
||||
applied_time_cutoff=response_sumary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
|
||||
level=level,
|
||||
level_question_nr=question_nr,
|
||||
level_question_num=question_num,
|
||||
),
|
||||
reference_db_search_docs,
|
||||
dropped_inds,
|
||||
@@ -310,17 +311,6 @@ ChatPacket = (
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
|
||||
# can't store a DbSearchDoc in a Pydantic BaseModel
|
||||
@dataclass
|
||||
class AnswerPostInfo:
|
||||
ai_message_files: list[FileDescriptor]
|
||||
qa_docs_response: QADocsResponse | None = None
|
||||
reference_db_search_docs: list[DbSearchDoc] | None = None
|
||||
dropped_indices: list[int] | None = None
|
||||
tool_result: ToolCallFinalResult | None = None
|
||||
message_specific_citations: MessageSpecificCitations | None = None
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
@@ -794,18 +784,22 @@ def stream_chat_message_objects(
|
||||
# tool_result = None
|
||||
|
||||
# TODO: different channels for stored info when it's coming from the agent flow
|
||||
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
lambda: AnswerPostInfo(ai_message_files=[])
|
||||
)
|
||||
refined_answer_improvement = True
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
level, level_question_nr = (
|
||||
(packet.level, packet.level_question_nr)
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if isinstance(packet, ExtendedToolResponse)
|
||||
else BASIC_KEY
|
||||
)
|
||||
info = info_by_subq[(level, level_question_nr)]
|
||||
assert level is not None
|
||||
assert level_question_num is not None
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
@@ -928,13 +922,15 @@ def stream_chat_message_objects(
|
||||
yield packet
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
level, level_question_nr = (
|
||||
(packet.level, packet.level_question_nr)
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if packet.level is not None
|
||||
and packet.level_question_nr is not None
|
||||
and packet.level_question_num is not None
|
||||
else BASIC_KEY
|
||||
)
|
||||
info = info_by_subq[(level, level_question_nr)]
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
info.tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
@@ -971,26 +967,30 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
subq_citations = answer.citations_by_subquestion()
|
||||
for pair in subq_citations:
|
||||
level, level_question_nr = pair
|
||||
info = info_by_subq[(level, level_question_nr)]
|
||||
for subq_key in subq_citations:
|
||||
info = info_by_subq[subq_key]
|
||||
logger.debug("Post-LLM answer processing")
|
||||
if info.reference_db_search_docs:
|
||||
info.message_specific_citations = _translate_citations(
|
||||
citations_list=subq_citations[pair],
|
||||
citations_list=subq_citations[subq_key],
|
||||
db_docs=info.reference_db_search_docs,
|
||||
)
|
||||
|
||||
# TODO: AllCitations should contain subq info?
|
||||
if not answer.is_cancelled():
|
||||
yield AllCitations(citations=subq_citations[pair])
|
||||
yield AllCitations(citations=subq_citations[subq_key])
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
|
||||
info = (
|
||||
info_by_subq[BASIC_KEY]
|
||||
info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])]
|
||||
if BASIC_KEY in info_by_subq
|
||||
else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
|
||||
else info_by_subq[
|
||||
SubQuestionKey(
|
||||
level=AGENT_SEARCH_INITIAL_KEY[0],
|
||||
question_num=AGENT_SEARCH_INITIAL_KEY[1],
|
||||
)
|
||||
]
|
||||
)
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
@@ -1025,7 +1025,11 @@ def stream_chat_message_objects(
|
||||
agent_answers = answer.llm_answer_by_level()
|
||||
while next_level in agent_answers:
|
||||
next_answer = agent_answers[next_level]
|
||||
info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(
|
||||
level=next_level, question_num=AGENT_SEARCH_INITIAL_KEY[1]
|
||||
)
|
||||
]
|
||||
next_answer_message = create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=prev_message,
|
||||
|
Reference in New Issue
Block a user