From 2d8486bac42129400478b68382a590650d3c9c5b Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Fri, 24 Jan 2025 15:03:10 -0800 Subject: [PATCH] stop infos when done streaming answers --- .../main/nodes/generate_initial_answer.py | 5 +++++ .../main/nodes/generate_refined_answer.py | 4 ++++ .../agents/agent_search/shared_graph_utils/utils.py | 12 ++++++++++++ backend/onyx/chat/models.py | 2 +- 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py index 0eb033240..824475d03 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py @@ -43,6 +43,9 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ( SUB_QUESTION_ANSWER_TEMPLATE, ) from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER +from onyx.agents.agent_search.shared_graph_utils.utils import ( + dispatch_main_answer_stop_info, +) from onyx.agents.agent_search.shared_graph_utils.utils import format_docs from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt @@ -84,6 +87,7 @@ def generate_initial_answer( answer_type="agent_level_answer", ), ) + dispatch_main_answer_stop_info(0) answer = UNKNOWN_ANSWER initial_agent_stats = InitialAgentResultStats( @@ -209,6 +213,7 @@ def generate_initial_answer( ) streamed_tokens.append(content) + dispatch_main_answer_stop_info(0) response = merge_content(*streamed_tokens) answer = cast(str, response) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py index 6df691b97..a121994fc 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py @@ -40,6 +40,9 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ( SUB_QUESTION_ANSWER_TEMPLATE, ) from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER +from onyx.agents.agent_search.shared_graph_utils.utils import ( + dispatch_main_answer_stop_info, +) from onyx.agents.agent_search.shared_graph_utils.utils import format_docs from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt @@ -216,6 +219,7 @@ def generate_refined_answer( ) streamed_tokens.append(content) + dispatch_main_answer_stop_info(1) response = merge_content(*streamed_tokens) answer = cast(str, response) diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index 084cec8ea..4bb3a1547 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -10,6 +10,7 @@ from typing import Any from typing import cast from uuid import UUID +from langchain_core.callbacks.manager import dispatch_custom_event from langchain_core.messages import BaseMessage from langchain_core.messages import HumanMessage from sqlalchemy.orm import Session @@ -23,6 +24,8 @@ from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationConfig from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import PromptConfig +from onyx.chat.models import StreamStopInfo +from onyx.chat.models import StreamStopReason from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT @@ -285,5 +288,14 @@ def dispatch_separated( return streamed_tokens +def dispatch_main_answer_stop_info(level: int) -> None: + stop_event = StreamStopInfo( + stop_reason=StreamStopReason.FINISHED, + stream_type="main_answer", + level=level, + ) + dispatch_custom_event("stream_finished", stop_event) + + def get_today_prompt() -> str: return DATE_PROMPT.format(date=datetime.now().strftime("%A, %B %d, %Y")) diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 55fb5bfc0..b71a5c5ee 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -70,7 +70,7 @@ class StreamStopReason(Enum): class StreamStopInfo(BaseModel): stop_reason: StreamStopReason - stream_type: Literal["", "sub_questions", "sub_answer"] = "" + stream_type: Literal["", "sub_questions", "sub_answer", "main_answer"] = "" # used to identify the stream that was stopped for agent search level: int | None = None level_question_nr: int | None = None