large number of PR comments addressed

Author: Evan Lohn
Date: 2025-01-31 21:20:51 -08:00
parent 118e8afbef
commit 5a95a5c9fd
37 changed files with 244 additions and 212 deletions

View File

@@ -19,6 +19,7 @@ from onyx.chat.models import CitationInfo
 from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.models import SubQuestionKey
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import BASIC_KEY
 from onyx.context.search.models import SearchRequest
@@ -32,6 +33,8 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()
 
+BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
+
 
 class Answer:
     def __init__(
@@ -164,6 +167,7 @@ class Answer:
                 and packet.answer_piece
                 and packet.answer_type == "agent_level_answer"
             ):
+                assert packet.level is not None
                 answer_by_level[packet.level] += packet.answer_piece
             elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
                 answer_by_level[BASIC_KEY[0]] += packet.answer_piece
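
The new assert above is presumably there for the type checker: `packet.level` is typed `int | None` on the shared identifier model, and the assert narrows it to `int` before it is used as a dict key. A minimal sketch of the pattern (the names here are illustrative stand-ins, not the real Onyx types):

    from collections import defaultdict

    def accumulate(answer_by_level: defaultdict[int, str], level: int | None, piece: str) -> None:
        # without this assert, mypy flags answer_by_level[level] because level may be None
        assert level is not None
        answer_by_level[level] += piece

    answers: defaultdict[int, str] = defaultdict(str)
    accumulate(answers, 0, "partial answer text")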
@@ -178,19 +182,20 @@
         return citations
 
-    # TODO: replace tuple of ints with SubQuestionId EVERYWHERE
-    def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
+    def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
         citations_by_subquestion: dict[
-            tuple[int, int], list[CitationInfo]
+            SubQuestionKey, list[CitationInfo]
         ] = defaultdict(list)
         for packet in self.processed_streamed_output:
             if isinstance(packet, CitationInfo):
-                if packet.level_question_nr is not None and packet.level is not None:
+                if packet.level_question_num is not None and packet.level is not None:
                     citations_by_subquestion[
-                        (packet.level, packet.level_question_nr)
+                        SubQuestionKey(
+                            level=packet.level, question_num=packet.level_question_num
+                        )
                     ].append(packet)
                 elif packet.level is None:
-                    citations_by_subquestion[BASIC_KEY].append(packet)
+                    citations_by_subquestion[BASIC_SQ_KEY].append(packet)
         return citations_by_subquestion
 
 
     def is_cancelled(self) -> bool:
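
Condensed, the grouping logic above now works like this: citations that carry agent level info get their own `SubQuestionKey` bucket, while level-less citations fall back to `BASIC_SQ_KEY`. A runnable sketch with stand-in types (`FakeCitation` is hypothetical, and `BASIC_KEY = (0, 0)` is an assumed value; the real constant lives in `onyx.configs.constants`):

    from collections import defaultdict
    from dataclasses import dataclass
    from pydantic import BaseModel

    class SubQuestionKey(BaseModel):  # mirrors the class added to models.py below
        level: int
        question_num: int

        def __hash__(self) -> int:
            return hash((self.level, self.question_num))

        def __eq__(self, other: object) -> bool:
            return isinstance(other, SubQuestionKey) and (
                self.level, self.question_num
            ) == (other.level, other.question_num)

    BASIC_KEY = (0, 0)  # assumed value for illustration
    BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])

    @dataclass
    class FakeCitation:  # hypothetical stand-in for CitationInfo
        level: int | None = None
        level_question_num: int | None = None

    buckets: dict[SubQuestionKey, list[FakeCitation]] = defaultdict(list)
    for c in [FakeCitation(1, 2), FakeCitation(1, 2), FakeCitation()]:
        if c.level is not None and c.level_question_num is not None:
            key = SubQuestionKey(level=c.level, question_num=c.level_question_num)
        else:
            key = BASIC_SQ_KEY
        buckets[key].append(c)

    assert len(buckets) == 2  # the two equal agent keys collapse into one bucket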

View File

@@ -16,6 +16,8 @@ from onyx.context.search.enums import QueryFlow
 from onyx.context.search.enums import RecencyBiasSetting
 from onyx.context.search.enums import SearchType
 from onyx.context.search.models import RetrievalDocs
+from onyx.db.models import SearchDoc as DbSearchDoc
+from onyx.file_store.models import FileDescriptor
 from onyx.llm.override_models import PromptOverride
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
@@ -41,16 +43,19 @@ class LlmDoc(BaseModel):
     match_highlights: list[str] | None
 
 
+class SubQuestionIdentifier(BaseModel):
+    level: int | None = None
+    level_question_num: int | None = None
+
+
 # First chunk of info for streaming QA
-class QADocsResponse(RetrievalDocs):
+class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
     rephrased_query: str | None = None
     predicted_flow: QueryFlow | None
     predicted_search: SearchType | None
     applied_source_filters: list[DocumentSource] | None
     applied_time_cutoff: datetime | None
     recency_bias_multiplier: float
-    level: int | None = None
-    level_question_nr: int | None = None
 
     def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
         initial_dict = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
@@ -67,13 +72,16 @@ class StreamStopReason(Enum):
     FINISHED = "finished"
 
 
-class StreamStopInfo(BaseModel):
+class StreamType(Enum):
+    SUB_QUESTIONS = "sub_questions"
+    SUB_ANSWER = "sub_answer"
+    MAIN_ANSWER = "main_answer"
+
+
+class StreamStopInfo(SubQuestionIdentifier):
     stop_reason: StreamStopReason
-    stream_type: Literal["", "sub_questions", "sub_answer", "main_answer"] = ""
 
     # used to identify the stream that was stopped for agent search
-    level: int | None = None
-    level_question_nr: int | None = None
+    stream_type: StreamType = StreamType.MAIN_ANSWER
 
     def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
         data = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
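
Replacing the `Literal` union with a real `StreamType` enum is mostly a typing change, but two details are worth noting: the default moves from `""` to `StreamType.MAIN_ANSWER`, and the wire format should be unchanged because Pydantic dumps enum members by value in JSON mode. A minimal sketch, assuming Pydantic v2:

    from enum import Enum
    from pydantic import BaseModel

    class StreamType(Enum):
        SUB_QUESTIONS = "sub_questions"
        SUB_ANSWER = "sub_answer"
        MAIN_ANSWER = "main_answer"

    class StreamStopInfo(BaseModel):
        stream_type: StreamType = StreamType.MAIN_ANSWER

    # enum members serialize to their string values in JSON mode
    assert StreamStopInfo().model_dump(mode="json") == {"stream_type": "main_answer"}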
@@ -114,11 +122,9 @@ class OnyxAnswerPiece(BaseModel):
 
 # An intermediate representation of citations, later translated into
 # a mapping of the citation [n] number to SearchDoc
-class CitationInfo(BaseModel):
+class CitationInfo(SubQuestionIdentifier):
     citation_num: int
     document_id: str
-    level: int | None = None
-    level_question_nr: int | None = None
 
 
 class AllCitations(BaseModel):
@@ -310,29 +316,22 @@ class PromptConfig(BaseModel):
     model_config = ConfigDict(frozen=True)
 
 
-class SubQueryPiece(BaseModel):
+class SubQueryPiece(SubQuestionIdentifier):
     sub_query: str
-    level: int
-    level_question_nr: int
     query_id: int
 
 
-class AgentAnswerPiece(BaseModel):
+class AgentAnswerPiece(SubQuestionIdentifier):
     answer_piece: str
-    level: int
-    level_question_nr: int
     answer_type: Literal["agent_sub_answer", "agent_level_answer"]
 
 
-class SubQuestionPiece(BaseModel):
+class SubQuestionPiece(SubQuestionIdentifier):
     sub_question: str
-    level: int
-    level_question_nr: int
 
 
-class ExtendedToolResponse(ToolResponse):
-    level: int
-    level_question_nr: int
+class ExtendedToolResponse(ToolResponse, SubQuestionIdentifier):
+    pass
 
 
 class RefinedAnswerImprovement(BaseModel):
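
Folding `level`/`level_question_num` into the shared `SubQuestionIdentifier` base deduplicates the field declarations, but it also loosens validation: models such as `SubQueryPiece` and `AgentAnswerPiece` previously declared these as required `int` fields and now inherit them as `int | None = None`. A small sketch of the new shape (field names from the diff; the rest is illustrative):

    from pydantic import BaseModel

    class SubQuestionIdentifier(BaseModel):
        level: int | None = None
        level_question_num: int | None = None

    class SubQueryPiece(SubQuestionIdentifier):
        sub_query: str
        query_id: int

    # before this change, omitting level/level_question_num was a validation error
    piece = SubQueryPiece(sub_query="q", query_id=0)
    assert piece.level is None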
@@ -363,3 +362,29 @@ ResponsePart = (
 )
 
 AnswerStream = Iterator[AnswerPacket]
+
+
+class AnswerPostInfo(BaseModel):
+    ai_message_files: list[FileDescriptor]
+    qa_docs_response: QADocsResponse | None = None
+    reference_db_search_docs: list[DbSearchDoc] | None = None
+    dropped_indices: list[int] | None = None
+    tool_result: ToolCallFinalResult | None = None
+    message_specific_citations: MessageSpecificCitations | None = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class SubQuestionKey(BaseModel):
+    level: int
+    question_num: int
+
+    def __hash__(self) -> int:
+        return hash((self.level, self.question_num))
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, SubQuestionKey) and (
+            self.level,
+            self.question_num,
+        ) == (other.level, other.question_num)
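
The explicit `__hash__`/`__eq__` pair on `SubQuestionKey` is what makes it usable as a `defaultdict` key: non-frozen Pydantic models define `__eq__` and are unhashable by default, so without the override every lookup would raise `TypeError`. A quick check, assuming Pydantic v2:

    from pydantic import BaseModel

    class PlainKey(BaseModel):
        level: int

    try:
        hash(PlainKey(level=0))
    except TypeError:
        pass  # non-frozen pydantic models are not hashable out of the box

    class SubQuestionKey(BaseModel):  # as in the diff above
        level: int
        question_num: int

        def __hash__(self) -> int:
            return hash((self.level, self.question_num))

        def __eq__(self, other: object) -> bool:
            return isinstance(other, SubQuestionKey) and (
                self.level, self.question_num
            ) == (other.level, other.question_num)

    d = {SubQuestionKey(level=0, question_num=0): "info"}
    # separately constructed but structurally equal keys find the same entry
    assert d[SubQuestionKey(level=0, question_num=0)] == "info"

An alternative design would be `model_config = ConfigDict(frozen=True)`, which generates both methods automatically, at the cost of making the fields immutable.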

View File

@@ -2,7 +2,6 @@ import traceback
 from collections import defaultdict
 from collections.abc import Callable
 from collections.abc import Iterator
-from dataclasses import dataclass
 from functools import partial
 from typing import cast
 
@@ -13,6 +12,7 @@ from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
 from onyx.chat.models import AgentSearchPacket
 from onyx.chat.models import AllCitations
+from onyx.chat.models import AnswerPostInfo
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ChatOnyxBotResponse
 from onyx.chat.models import CitationConfig
@@ -33,6 +33,7 @@ from onyx.chat.models import RefinedAnswerImprovement
 from onyx.chat.models import StreamingError
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.models import SubQuestionKey
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
@@ -196,9 +197,9 @@ def _handle_search_tool_response_summary(
         for db_search_doc in reference_db_search_docs
     ]
 
-    level, question_nr = None, None
+    level, question_num = None, None
     if isinstance(packet, ExtendedToolResponse):
-        level, question_nr = packet.level, packet.level_question_nr
+        level, question_num = packet.level, packet.level_question_num
     return (
         QADocsResponse(
             rephrased_query=response_sumary.rephrased_query,
@@ -209,7 +210,7 @@
             applied_time_cutoff=response_sumary.final_filters.time_cutoff,
             recency_bias_multiplier=response_sumary.recency_bias_multiplier,
             level=level,
-            level_question_nr=question_nr,
+            level_question_num=question_num,
         ),
         reference_db_search_docs,
         dropped_inds,
@@ -310,17 +311,6 @@ ChatPacket = (
 ChatPacketStream = Iterator[ChatPacket]
 
 
-# can't store a DbSearchDoc in a Pydantic BaseModel
-@dataclass
-class AnswerPostInfo:
-    ai_message_files: list[FileDescriptor]
-    qa_docs_response: QADocsResponse | None = None
-    reference_db_search_docs: list[DbSearchDoc] | None = None
-    dropped_indices: list[int] | None = None
-    tool_result: ToolCallFinalResult | None = None
-    message_specific_citations: MessageSpecificCitations | None = None
-
-
 def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
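
The dataclass removed here (along with its "can't store a DbSearchDoc in a Pydantic BaseModel" caveat) moves to models.py as a `BaseModel` with `arbitrary_types_allowed`, the standard escape hatch for fields whose types Pydantic cannot build a schema for, such as SQLAlchemy models like `DbSearchDoc`. A minimal sketch of why the config flag is needed (`Opaque` is a hypothetical stand-in):

    from pydantic import BaseModel

    class Opaque:  # hypothetical stand-in for a SQLAlchemy model such as DbSearchDoc
        pass

    class AnswerPostInfo(BaseModel):
        ref: Opaque | None = None

        class Config:  # same style the diff uses; ConfigDict is the v2-native spelling
            arbitrary_types_allowed = True

    # without arbitrary_types_allowed, defining the model above raises a schema error
    AnswerPostInfo(ref=Opaque())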
@@ -794,18 +784,22 @@
         # tool_result = None
 
         # TODO: different channels for stored info when it's coming from the agent flow
-        info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
+        info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
             lambda: AnswerPostInfo(ai_message_files=[])
         )
         refined_answer_improvement = True
         for packet in answer.processed_streamed_output:
             if isinstance(packet, ToolResponse):
-                level, level_question_nr = (
-                    (packet.level, packet.level_question_nr)
+                level, level_question_num = (
+                    (packet.level, packet.level_question_num)
                     if isinstance(packet, ExtendedToolResponse)
                     else BASIC_KEY
                 )
-                info = info_by_subq[(level, level_question_nr)]
+                assert level is not None
+                assert level_question_num is not None
+                info = info_by_subq[
+                    SubQuestionKey(level=level, question_num=level_question_num)
+                ]
                 # TODO: don't need to dedupe here when we do it in agent flow
                 if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
                     (
@@ -928,13 +922,15 @@
                     yield packet
             else:
                 if isinstance(packet, ToolCallFinalResult):
-                    level, level_question_nr = (
-                        (packet.level, packet.level_question_nr)
+                    level, level_question_num = (
+                        (packet.level, packet.level_question_num)
                         if packet.level is not None
-                        and packet.level_question_nr is not None
+                        and packet.level_question_num is not None
                         else BASIC_KEY
                     )
-                    info = info_by_subq[(level, level_question_nr)]
+                    info = info_by_subq[
+                        SubQuestionKey(level=level, question_num=level_question_num)
+                    ]
                     info.tool_result = packet
                 yield cast(ChatPacket, packet)
         logger.debug("Reached end of stream")
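
Both hunks above share the same shape: take `(level, level_question_num)` from the packet when it carries agent info, otherwise fall back to `BASIC_KEY`, then assert the values are concrete before building the `SubQuestionKey`. A compact, self-contained sketch of that control flow (the `BASIC_KEY` value is assumed for illustration):

    BASIC_KEY = (0, 0)  # assumed; the real constant lives in onyx.configs.constants

    def resolve_ids(level: int | None, question_num: int | None) -> tuple[int, int]:
        level, question_num = (
            (level, question_num)
            if level is not None and question_num is not None
            else BASIC_KEY
        )
        # mirrors the asserts in the diff, which narrow int | None to int
        assert level is not None and question_num is not None
        return level, question_num

    assert resolve_ids(None, None) == BASIC_KEY
    assert resolve_ids(1, 2) == (1, 2)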
@@ -971,26 +967,30 @@
             tool_name_to_tool_id[tool.name] = tool_id
 
         subq_citations = answer.citations_by_subquestion()
-        for pair in subq_citations:
-            level, level_question_nr = pair
-            info = info_by_subq[(level, level_question_nr)]
+        for subq_key in subq_citations:
+            info = info_by_subq[subq_key]
             logger.debug("Post-LLM answer processing")
             if info.reference_db_search_docs:
                 info.message_specific_citations = _translate_citations(
-                    citations_list=subq_citations[pair],
+                    citations_list=subq_citations[subq_key],
                     db_docs=info.reference_db_search_docs,
                 )
 
             # TODO: AllCitations should contain subq info?
             if not answer.is_cancelled():
-                yield AllCitations(citations=subq_citations[pair])
+                yield AllCitations(citations=subq_citations[subq_key])
 
     # Saving Gen AI answer and responding with message info
     info = (
-        info_by_subq[BASIC_KEY]
+        info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])]
         if BASIC_KEY in info_by_subq
-        else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
+        else info_by_subq[
+            SubQuestionKey(
+                level=AGENT_SEARCH_INITIAL_KEY[0],
+                question_num=AGENT_SEARCH_INITIAL_KEY[1],
+            )
+        ]
     )
 
     gen_ai_response_message = partial_response(
         message=answer.llm_answer,
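
One thing worth double-checking in the hunk above: `if BASIC_KEY in info_by_subq` still tests the raw tuple, but the dict is now keyed by `SubQuestionKey`, whose `__eq__` requires an `isinstance` match, so the test looks like it can never be true and the basic flow would always fall through to the `AGENT_SEARCH_INITIAL_KEY` entry. A small self-contained check (`BASIC_KEY = (0, 0)` assumed):

    from pydantic import BaseModel

    class SubQuestionKey(BaseModel):
        level: int
        question_num: int

        def __hash__(self) -> int:
            return hash((self.level, self.question_num))

        def __eq__(self, other: object) -> bool:
            return isinstance(other, SubQuestionKey) and (
                self.level, self.question_num
            ) == (other.level, other.question_num)

    BASIC_KEY = (0, 0)  # assumed value for illustration
    info_by_subq = {SubQuestionKey(level=0, question_num=0): "stored info"}

    # the hashes collide as intended, but tuple == SubQuestionKey is False in both
    # directions, so membership fails even though the "same" key is present
    assert BASIC_KEY not in info_by_subq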
@@ -1025,7 +1025,11 @@
             agent_answers = answer.llm_answer_by_level()
             while next_level in agent_answers:
                 next_answer = agent_answers[next_level]
-                info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
+                info = info_by_subq[
+                    SubQuestionKey(
+                        level=next_level, question_num=AGENT_SEARCH_INITIAL_KEY[1]
+                    )
+                ]
                 next_answer_message = create_new_chat_message(
                     chat_session_id=chat_session_id,
                     parent_message=prev_message,