large number of PR comments addressed

This commit is contained in:
Evan Lohn
2025-01-31 21:20:51 -08:00
parent 118e8afbef
commit 5a95a5c9fd
37 changed files with 244 additions and 212 deletions

View File

@@ -2,7 +2,6 @@ import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import cast
@@ -13,6 +12,7 @@ from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerPostInfo
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatOnyxBotResponse
from onyx.chat.models import CitationConfig
@@ -33,6 +33,7 @@ from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
@@ -196,9 +197,9 @@ def _handle_search_tool_response_summary(
for db_search_doc in reference_db_search_docs
]
level, question_nr = None, None
level, question_num = None, None
if isinstance(packet, ExtendedToolResponse):
level, question_nr = packet.level, packet.level_question_nr
level, question_num = packet.level, packet.level_question_num
return (
QADocsResponse(
rephrased_query=response_sumary.rephrased_query,
@@ -209,7 +210,7 @@ def _handle_search_tool_response_summary(
applied_time_cutoff=response_sumary.final_filters.time_cutoff,
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
level=level,
level_question_nr=question_nr,
level_question_num=question_num,
),
reference_db_search_docs,
dropped_inds,
@@ -310,17 +311,6 @@ ChatPacket = (
ChatPacketStream = Iterator[ChatPacket]
# Per-sub-question bookkeeping collected by stream_chat_message_objects while
# it consumes the answer stream (one instance per key of `info_by_subq`).
# Kept as a plain dataclass because a DbSearchDoc can't be stored in a
# Pydantic BaseModel.
@dataclass
class AnswerPostInfo:
# The only required field; the `info_by_subq` default factory starts it as [].
ai_message_files: list[FileDescriptor]
# NOTE(review): presumably the QADocsResponse built by
# _handle_search_tool_response_summary — the assignment is not visible here; confirm.
qa_docs_response: QADocsResponse | None = None
# Passed as `db_docs` to _translate_citations when non-empty.
reference_db_search_docs: list[DbSearchDoc] | None = None
# NOTE(review): presumably indices of docs dropped during search-result
# handling — confirm against the caller.
dropped_indices: list[int] | None = None
# Set to the ToolCallFinalResult packet when one is seen in the stream.
tool_result: ToolCallFinalResult | None = None
# Set from _translate_citations(...) during post-LLM answer processing.
message_specific_citations: MessageSpecificCitations | None = None
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
@@ -794,18 +784,22 @@ def stream_chat_message_objects(
# tool_result = None
# TODO: different channels for stored info when it's coming from the agent flow
info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
level, level_question_nr = (
(packet.level, packet.level_question_nr)
level, level_question_num = (
(packet.level, packet.level_question_num)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
info = info_by_subq[(level, level_question_nr)]
assert level is not None
assert level_question_num is not None
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
@@ -928,13 +922,15 @@ def stream_chat_message_objects(
yield packet
else:
if isinstance(packet, ToolCallFinalResult):
level, level_question_nr = (
(packet.level, packet.level_question_nr)
level, level_question_num = (
(packet.level, packet.level_question_num)
if packet.level is not None
and packet.level_question_nr is not None
and packet.level_question_num is not None
else BASIC_KEY
)
info = info_by_subq[(level, level_question_nr)]
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
info.tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
@@ -971,26 +967,30 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
subq_citations = answer.citations_by_subquestion()
for pair in subq_citations:
level, level_question_nr = pair
info = info_by_subq[(level, level_question_nr)]
for subq_key in subq_citations:
info = info_by_subq[subq_key]
logger.debug("Post-LLM answer processing")
if info.reference_db_search_docs:
info.message_specific_citations = _translate_citations(
citations_list=subq_citations[pair],
citations_list=subq_citations[subq_key],
db_docs=info.reference_db_search_docs,
)
# TODO: AllCitations should contain subq info?
if not answer.is_cancelled():
yield AllCitations(citations=subq_citations[pair])
yield AllCitations(citations=subq_citations[subq_key])
# Saving Gen AI answer and responding with message info
info = (
info_by_subq[BASIC_KEY]
info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])]
if BASIC_KEY in info_by_subq
else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
else info_by_subq[
SubQuestionKey(
level=AGENT_SEARCH_INITIAL_KEY[0],
question_num=AGENT_SEARCH_INITIAL_KEY[1],
)
]
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
@@ -1025,7 +1025,11 @@ def stream_chat_message_objects(
agent_answers = answer.llm_answer_by_level()
while next_level in agent_answers:
next_answer = agent_answers[next_level]
info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
info = info_by_subq[
SubQuestionKey(
level=next_level, question_num=AGENT_SEARCH_INITIAL_KEY[1]
)
]
next_answer_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=prev_message,