large number of PR comments addressed
@@ -19,6 +19,7 @@ from onyx.chat.models import CitationInfo
 from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.models import SubQuestionKey
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import BASIC_KEY
 from onyx.context.search.models import SearchRequest
@@ -32,6 +33,8 @@ from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
+BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
+
 
 class Answer:
     def __init__(
@@ -164,6 +167,7 @@ class Answer:
                 and packet.answer_piece
                 and packet.answer_type == "agent_level_answer"
             ):
+                assert packet.level is not None
                 answer_by_level[packet.level] += packet.answer_piece
             elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
                 answer_by_level[BASIC_KEY[0]] += packet.answer_piece
@@ -178,19 +182,20 @@ class Answer:
 
         return citations
 
-    # TODO: replace tuple of ints with SubQuestionId EVERYWHERE
-    def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
+    def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
         citations_by_subquestion: dict[
-            tuple[int, int], list[CitationInfo]
+            SubQuestionKey, list[CitationInfo]
         ] = defaultdict(list)
         for packet in self.processed_streamed_output:
             if isinstance(packet, CitationInfo):
-                if packet.level_question_nr is not None and packet.level is not None:
+                if packet.level_question_num is not None and packet.level is not None:
                     citations_by_subquestion[
-                        (packet.level, packet.level_question_nr)
+                        SubQuestionKey(
+                            level=packet.level, question_num=packet.level_question_num
+                        )
                     ].append(packet)
                 elif packet.level is None:
-                    citations_by_subquestion[BASIC_KEY].append(packet)
+                    citations_by_subquestion[BASIC_SQ_KEY].append(packet)
         return citations_by_subquestion
 
     def is_cancelled(self) -> bool:
@@ -16,6 +16,8 @@ from onyx.context.search.enums import QueryFlow
 from onyx.context.search.enums import RecencyBiasSetting
 from onyx.context.search.enums import SearchType
 from onyx.context.search.models import RetrievalDocs
+from onyx.db.models import SearchDoc as DbSearchDoc
+from onyx.file_store.models import FileDescriptor
 from onyx.llm.override_models import PromptOverride
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
@@ -41,16 +43,19 @@ class LlmDoc(BaseModel):
     match_highlights: list[str] | None
 
 
+class SubQuestionIdentifier(BaseModel):
+    level: int | None = None
+    level_question_num: int | None = None
+
+
 # First chunk of info for streaming QA
-class QADocsResponse(RetrievalDocs):
+class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
     rephrased_query: str | None = None
     predicted_flow: QueryFlow | None
     predicted_search: SearchType | None
     applied_source_filters: list[DocumentSource] | None
     applied_time_cutoff: datetime | None
     recency_bias_multiplier: float
-    level: int | None = None
-    level_question_nr: int | None = None
 
     def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
         initial_dict = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
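Side note on the `SubQuestionIdentifier` mixin: Pydantic models compose via multiple inheritance, which is what lets the later classes in this diff (e.g. `ExtendedToolResponse(ToolResponse, SubQuestionIdentifier)`) drop their duplicated `level` / `level_question_num` fields. A minimal standalone sketch of the pattern, with a hypothetical `Payload` standing in for a real base like `ToolResponse`:

```python
from pydantic import BaseModel


class SubQuestionIdentifier(BaseModel):
    level: int | None = None
    level_question_num: int | None = None


class Payload(BaseModel):  # hypothetical stand-in, not from the codebase
    id: str | None = None


class ExtendedPayload(Payload, SubQuestionIdentifier):
    pass  # fields from both bases are merged into one model


p = ExtendedPayload(id="search", level=0, level_question_num=2)
print(p.level, p.level_question_num, p.id)  # 0 2 search
```

One trade-off the commit accepts: fields that used to be required `int`s on `SubQueryPiece` and friends become `int | None` via the shared mixin.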
@@ -67,13 +72,16 @@ class StreamStopReason(Enum):
     FINISHED = "finished"
 
 
-class StreamStopInfo(BaseModel):
+class StreamType(Enum):
+    SUB_QUESTIONS = "sub_questions"
+    SUB_ANSWER = "sub_answer"
+    MAIN_ANSWER = "main_answer"
+
+
+class StreamStopInfo(SubQuestionIdentifier):
     stop_reason: StreamStopReason
-
-    stream_type: Literal["", "sub_questions", "sub_answer", "main_answer"] = ""
     # used to identify the stream that was stopped for agent search
-    level: int | None = None
-    level_question_nr: int | None = None
+    stream_type: StreamType = StreamType.MAIN_ANSWER
 
     def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
         data = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
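For the `StreamType` change, a quick standalone sketch (`StopInfo` is a simplified stand-in for `StreamStopInfo`) of why the overridden `model_dump` passes `mode="json"`: that is what collapses the enum member to its string value on the wire.

```python
from enum import Enum

from pydantic import BaseModel


class StreamType(Enum):
    SUB_QUESTIONS = "sub_questions"
    SUB_ANSWER = "sub_answer"
    MAIN_ANSWER = "main_answer"


class StopInfo(BaseModel):  # simplified stand-in for StreamStopInfo
    stream_type: StreamType = StreamType.MAIN_ANSWER


print(StopInfo().model_dump())             # {'stream_type': <StreamType.MAIN_ANSWER: 'main_answer'>}
print(StopInfo().model_dump(mode="json"))  # {'stream_type': 'main_answer'}
```

Compared with the old stringly-typed `Literal[...] = ""` field, the enum drops the empty-string default and gives the valid values one definition.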
@@ -114,11 +122,9 @@ class OnyxAnswerPiece(BaseModel):
 
 # An intermediate representation of citations, later translated into
 # a mapping of the citation [n] number to SearchDoc
-class CitationInfo(BaseModel):
+class CitationInfo(SubQuestionIdentifier):
     citation_num: int
     document_id: str
-    level: int | None = None
-    level_question_nr: int | None = None
 
 
 class AllCitations(BaseModel):
@@ -310,29 +316,22 @@ class PromptConfig(BaseModel):
     model_config = ConfigDict(frozen=True)
 
 
-class SubQueryPiece(BaseModel):
+class SubQueryPiece(SubQuestionIdentifier):
     sub_query: str
-    level: int
-    level_question_nr: int
     query_id: int
 
 
-class AgentAnswerPiece(BaseModel):
+class AgentAnswerPiece(SubQuestionIdentifier):
     answer_piece: str
-    level: int
-    level_question_nr: int
     answer_type: Literal["agent_sub_answer", "agent_level_answer"]
 
 
-class SubQuestionPiece(BaseModel):
+class SubQuestionPiece(SubQuestionIdentifier):
     sub_question: str
-    level: int
-    level_question_nr: int
 
 
-class ExtendedToolResponse(ToolResponse):
-    level: int
-    level_question_nr: int
+class ExtendedToolResponse(ToolResponse, SubQuestionIdentifier):
+    pass
 
 
 class RefinedAnswerImprovement(BaseModel):
@@ -363,3 +362,29 @@ ResponsePart = (
 )
 
 AnswerStream = Iterator[AnswerPacket]
+
+
+class AnswerPostInfo(BaseModel):
+    ai_message_files: list[FileDescriptor]
+    qa_docs_response: QADocsResponse | None = None
+    reference_db_search_docs: list[DbSearchDoc] | None = None
+    dropped_indices: list[int] | None = None
+    tool_result: ToolCallFinalResult | None = None
+    message_specific_citations: MessageSpecificCitations | None = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
+class SubQuestionKey(BaseModel):
+    level: int
+    question_num: int
+
+    def __hash__(self) -> int:
+        return hash((self.level, self.question_num))
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, SubQuestionKey) and (
+            self.level,
+            self.question_num,
+        ) == (other.level, other.question_num)
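On `SubQuestionKey`: a plain (non-frozen) Pydantic model defines `__eq__` but not `__hash__`, so it cannot serve as a dict key; the explicit `__hash__`/`__eq__` above are what make the `defaultdict` lookups elsewhere in this commit work. A minimal standalone sketch of the behavior:

```python
from collections import defaultdict

from pydantic import BaseModel


class SubQuestionKey(BaseModel):
    level: int
    question_num: int

    def __hash__(self) -> int:
        return hash((self.level, self.question_num))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SubQuestionKey) and (
            self.level,
            self.question_num,
        ) == (other.level, other.question_num)


citations: dict[SubQuestionKey, list[str]] = defaultdict(list)
citations[SubQuestionKey(level=0, question_num=1)].append("doc-a")
# A separately constructed but equal key lands in the same bucket.
citations[SubQuestionKey(level=0, question_num=1)].append("doc-b")
print(len(citations))  # 1

# Because __eq__ requires isinstance, a raw tuple never matches a key:
print((0, 1) in citations)  # False
```

That last point is worth keeping in mind when reading the `if BASIC_KEY in info_by_subq` branch later in the diff, since `BASIC_KEY` is still a tuple there.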
@@ -2,7 +2,6 @@ import traceback
 from collections import defaultdict
 from collections.abc import Callable
 from collections.abc import Iterator
-from dataclasses import dataclass
 from functools import partial
 from typing import cast
 
@@ -13,6 +12,7 @@ from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
 from onyx.chat.models import AgentSearchPacket
 from onyx.chat.models import AllCitations
+from onyx.chat.models import AnswerPostInfo
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ChatOnyxBotResponse
 from onyx.chat.models import CitationConfig
@@ -33,6 +33,7 @@ from onyx.chat.models import RefinedAnswerImprovement
 from onyx.chat.models import StreamingError
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.models import SubQuestionKey
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
@@ -196,9 +197,9 @@ def _handle_search_tool_response_summary(
         for db_search_doc in reference_db_search_docs
     ]
 
-    level, question_nr = None, None
+    level, question_num = None, None
     if isinstance(packet, ExtendedToolResponse):
-        level, question_nr = packet.level, packet.level_question_nr
+        level, question_num = packet.level, packet.level_question_num
     return (
         QADocsResponse(
             rephrased_query=response_sumary.rephrased_query,
@@ -209,7 +210,7 @@ def _handle_search_tool_response_summary(
             applied_time_cutoff=response_sumary.final_filters.time_cutoff,
             recency_bias_multiplier=response_sumary.recency_bias_multiplier,
             level=level,
-            level_question_nr=question_nr,
+            level_question_num=question_num,
         ),
         reference_db_search_docs,
         dropped_inds,
@@ -310,17 +311,6 @@ ChatPacket = (
 ChatPacketStream = Iterator[ChatPacket]
 
 
-# can't store a DbSearchDoc in a Pydantic BaseModel
-@dataclass
-class AnswerPostInfo:
-    ai_message_files: list[FileDescriptor]
-    qa_docs_response: QADocsResponse | None = None
-    reference_db_search_docs: list[DbSearchDoc] | None = None
-    dropped_indices: list[int] | None = None
-    tool_result: ToolCallFinalResult | None = None
-    message_specific_citations: MessageSpecificCitations | None = None
-
-
 def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
@@ -794,18 +784,22 @@ def stream_chat_message_objects(
         # tool_result = None
 
         # TODO: different channels for stored info when it's coming from the agent flow
-        info_by_subq: dict[tuple[int, int], AnswerPostInfo] = defaultdict(
+        info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
             lambda: AnswerPostInfo(ai_message_files=[])
         )
         refined_answer_improvement = True
         for packet in answer.processed_streamed_output:
             if isinstance(packet, ToolResponse):
-                level, level_question_nr = (
-                    (packet.level, packet.level_question_nr)
+                level, level_question_num = (
+                    (packet.level, packet.level_question_num)
                     if isinstance(packet, ExtendedToolResponse)
                     else BASIC_KEY
                 )
-                info = info_by_subq[(level, level_question_nr)]
+                assert level is not None
+                assert level_question_num is not None
+                info = info_by_subq[
+                    SubQuestionKey(level=level, question_num=level_question_num)
+                ]
                 # TODO: don't need to dedupe here when we do it in agent flow
                 if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
                     (
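The `info_by_subq` rewrite above combines two details: `defaultdict` needs a zero-argument factory, and since `AnswerPostInfo` has a required `ai_message_files` field, a lambda has to supply it. A small sketch with a simplified stand-in model (tuple keys here for brevity; the real code uses `SubQuestionKey`):

```python
from collections import defaultdict

from pydantic import BaseModel


class PostInfo(BaseModel):  # simplified stand-in for AnswerPostInfo
    ai_message_files: list[str]
    tool_result: str | None = None


info_by_subq: dict[tuple[int, int], PostInfo] = defaultdict(
    lambda: PostInfo(ai_message_files=[])
)

# First access for a key invokes the factory and creates a fresh entry.
info_by_subq[(0, 1)].tool_result = "search summary"
print(info_by_subq[(0, 1)].tool_result)  # search summary
print(info_by_subq[(0, 2)].tool_result)  # None
```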
@@ -928,13 +922,15 @@ def stream_chat_message_objects(
                 yield packet
             else:
                 if isinstance(packet, ToolCallFinalResult):
-                    level, level_question_nr = (
-                        (packet.level, packet.level_question_nr)
+                    level, level_question_num = (
+                        (packet.level, packet.level_question_num)
                         if packet.level is not None
-                        and packet.level_question_nr is not None
+                        and packet.level_question_num is not None
                         else BASIC_KEY
                     )
-                    info = info_by_subq[(level, level_question_nr)]
+                    info = info_by_subq[
+                        SubQuestionKey(level=level, question_num=level_question_num)
+                    ]
                     info.tool_result = packet
                 yield cast(ChatPacket, packet)
         logger.debug("Reached end of stream")
@@ -971,26 +967,30 @@ def stream_chat_message_objects(
             tool_name_to_tool_id[tool.name] = tool_id
 
         subq_citations = answer.citations_by_subquestion()
-        for pair in subq_citations:
-            level, level_question_nr = pair
-            info = info_by_subq[(level, level_question_nr)]
+        for subq_key in subq_citations:
+            info = info_by_subq[subq_key]
             logger.debug("Post-LLM answer processing")
             if info.reference_db_search_docs:
                 info.message_specific_citations = _translate_citations(
-                    citations_list=subq_citations[pair],
+                    citations_list=subq_citations[subq_key],
                     db_docs=info.reference_db_search_docs,
                 )
 
             # TODO: AllCitations should contain subq info?
            if not answer.is_cancelled():
-                yield AllCitations(citations=subq_citations[pair])
+                yield AllCitations(citations=subq_citations[subq_key])
 
         # Saving Gen AI answer and responding with message info
 
         info = (
-            info_by_subq[BASIC_KEY]
+            info_by_subq[SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])]
             if BASIC_KEY in info_by_subq
-            else info_by_subq[AGENT_SEARCH_INITIAL_KEY]
+            else info_by_subq[
+                SubQuestionKey(
+                    level=AGENT_SEARCH_INITIAL_KEY[0],
+                    question_num=AGENT_SEARCH_INITIAL_KEY[1],
+                )
+            ]
         )
         gen_ai_response_message = partial_response(
             message=answer.llm_answer,
@@ -1025,7 +1025,11 @@ def stream_chat_message_objects(
         agent_answers = answer.llm_answer_by_level()
         while next_level in agent_answers:
             next_answer = agent_answers[next_level]
-            info = info_by_subq[(next_level, AGENT_SEARCH_INITIAL_KEY[1])]
+            info = info_by_subq[
+                SubQuestionKey(
+                    level=next_level, question_num=AGENT_SEARCH_INITIAL_KEY[1]
+                )
+            ]
             next_answer_message = create_new_chat_message(
                 chat_session_id=chat_session_id,
                 parent_message=prev_message,