diff --git a/backend/onyx/agents/agent_search/basic/utils.py b/backend/onyx/agents/agent_search/basic/utils.py
index e647f46c3..cd0c63afa 100644
--- a/backend/onyx/agents/agent_search/basic/utils.py
+++ b/backend/onyx/agents/agent_search/basic/utils.py
@@ -1,10 +1,11 @@
 from collections.abc import Iterator
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage
+from langgraph.types import StreamWriter

+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import LlmDoc
 from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
 from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
@@ -20,6 +21,7 @@ logger = setup_logger()
 def process_llm_stream(
     messages: Iterator[BaseMessage],
     should_stream_answer: bool,
+    writer: StreamWriter,
     final_search_results: list[LlmDoc] | None = None,
     displayed_search_results: list[LlmDoc] | None = None,
 ) -> AIMessageChunk:
@@ -52,9 +54,10 @@ def process_llm_stream(
             tool_call_chunk += message  # type: ignore
         elif should_stream_answer:
             for response_part in answer_handler.handle_response_part(message, []):
-                dispatch_custom_event(
+                write_custom_event(
                     "basic_response",
                     response_part,
+                    writer,
                 )

     logger.info(f"Full answer: {full_answer}")
diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
index 64c330f7d..56706570e 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
@@ -2,9 +2,9 @@ from datetime import datetime
 from typing import Any
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.initial.generate_individual_sub_answer.states import (
     AnswerQuestionState,
@@ -22,6 +22,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_persona_agent_prompt_expressions,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
@@ -32,7 +33,9 @@ logger = setup_logger()


 def generate_sub_answer(
-    state: AnswerQuestionState, config: RunnableConfig
+    state: AnswerQuestionState,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
 ) -> QAGenerationUpdate:
     now_start = datetime.now()
     logger.info(f"--------{now_start}--------START ANSWER GENERATION---")
@@ -48,7 +51,7 @@ def generate_sub_answer(

     if len(context_docs) == 0:
         answer_str = NO_RECOVERED_DOCS
-        dispatch_custom_event(
+        write_custom_event(
             "sub_answers",
             AgentAnswerPiece(
                 answer_piece=answer_str,
@@ -56,6 +59,7 @@ def generate_sub_answer(
                 level_question_nr=question_nr,
                 answer_type="agent_sub_answer",
             ),
+            writer,
         )
     else:
         logger.debug(f"Number of verified retrieval docs: {len(context_docs)}")
@@ -81,7 +85,7 @@ def generate_sub_answer(
                     f"Expected content to be a string, but got {type(content)}"
                 )
             start_stream_token = datetime.now()
-            dispatch_custom_event(
+            write_custom_event(
                 "sub_answers",
                 AgentAnswerPiece(
                     answer_piece=content,
@@ -89,6 +93,7 @@ def generate_sub_answer(
                     level_question_nr=question_nr,
                     answer_type="agent_sub_answer",
                 ),
+                writer,
             )
             end_stream_token = datetime.now()
             dispatch_timings.append(
@@ -112,7 +117,7 @@ def generate_sub_answer(
         level=level,
         level_question_nr=question_nr,
     )
-    dispatch_custom_event("stream_finished", stop_event)
+    write_custom_event("stream_finished", stop_event, writer)

     now_end = datetime.now()

diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py
index 52689a1ed..45daec535 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py
@@ -2,10 +2,10 @@ from datetime import datetime
 from typing import Any
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.initial.generate_initial_answer.states import (
     SearchSQState,
@@ -45,6 +45,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@@ -54,7 +55,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp


 def generate_initial_answer(
-    state: SearchSQState, config: RunnableConfig
+    state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> InitialAnswerUpdate:
     now_start = datetime.now()

@@ -98,7 +99,7 @@ def generate_initial_answer(
         get_section_relevance=lambda: None,  # TODO: add relevance
         search_tool=agent_a_config.search_tool,
     ):
-        dispatch_custom_event(
+        write_custom_event(
             "tool_response",
             ExtendedToolResponse(
                 id=tool_response.id,
@@ -106,10 +107,11 @@ def generate_initial_answer(
                 level=0,
                 level_question_nr=0,  # 0, 0 is the base question
             ),
+            writer,
         )

     if len(relevant_docs) == 0:
-        dispatch_custom_event(
+        write_custom_event(
             "initial_agent_answer",
             AgentAnswerPiece(
                 answer_piece=UNKNOWN_ANSWER,
@@ -117,8 +119,9 @@ def generate_initial_answer(
                 level_question_nr=0,
                 answer_type="agent_level_answer",
             ),
+            writer,
         )
-        dispatch_main_answer_stop_info(0)
+        dispatch_main_answer_stop_info(0, writer)

         answer = UNKNOWN_ANSWER
         initial_agent_stats = InitialAgentResultStats(
@@ -197,7 +200,7 @@ def generate_initial_answer(
                     )

                 start_stream_token = datetime.now()
-                dispatch_custom_event(
+                write_custom_event(
                     "initial_agent_answer",
                     AgentAnswerPiece(
                         answer_piece=content,
@@ -205,6 +208,7 @@ def generate_initial_answer(
                         level_question_nr=0,
                         answer_type="agent_level_answer",
                     ),
+                    writer,
                 )
                 end_stream_token = datetime.now()
                 dispatch_timings.append(
@@ -216,7 +220,7 @@ def generate_initial_answer(
         f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
     )

-    dispatch_main_answer_stop_info(0)
+    dispatch_main_answer_stop_info(0, writer)

     response = merge_content(*streamed_tokens)
     answer = cast(str, response)
diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_sub_answers/nodes/decompose_orig_question.py b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_sub_answers/nodes/decompose_orig_question.py
index bfb059508..5c23035b9 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_sub_answers/nodes/decompose_orig_question.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_sub_answers/nodes/decompose_orig_question.py
@@ -1,10 +1,10 @@
 from datetime import datetime
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.initial.generate_initial_answer.states import (
     SearchSQState,
@@ -28,6 +28,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import (
     INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import SubQuestionPiece
@@ -35,7 +36,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION


 def decompose_orig_question(
-    state: SearchSQState, config: RunnableConfig
+    state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> BaseDecompUpdate:
     now_start = datetime.now()

@@ -90,23 +91,26 @@ def decompose_orig_question(
     msg = [HumanMessage(content=decomposition_prompt)]

     # Send the initial question as a subquestion with number 0
-    dispatch_custom_event(
+    write_custom_event(
         "decomp_qs",
         SubQuestionPiece(
             sub_question=question,
             level=0,
             level_question_nr=0,
         ),
+        writer,
     )

     # dispatches custom events for subquestion tokens, adding in subquestion ids.
-    streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
+    streamed_tokens = dispatch_separated(
+        model.stream(msg), dispatch_subquestion(0, writer)
+    )

     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,
         stream_type="sub_questions",
         level=0,
     )
-    dispatch_custom_event("stream_finished", stop_event)
+    write_custom_event("stream_finished", stop_event, writer)

     deomposition_response = merge_content(*streamed_tokens)
diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/compare_answers.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/compare_answers.py
index 89e76df82..d3dc4dedd 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/compare_answers.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/compare_answers.py
@@ -1,19 +1,22 @@
 from datetime import datetime
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.main.operations import logger
 from onyx.agents.agent_search.deep_search_a.main.states import AnswerComparison
 from onyx.agents.agent_search.deep_search_a.main.states import MainState
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import RefinedAnswerImprovement


-def compare_answers(state: MainState, config: RunnableConfig) -> AnswerComparison:
+def compare_answers(
+    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
+) -> AnswerComparison:
     now_start = datetime.now()

     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
@@ -39,11 +42,12 @@ def compare_answers(state: MainState, config: RunnableConfig) -> AnswerCompariso
         isinstance(resp.content, str) and "yes" in resp.content.lower()
     )

-    dispatch_custom_event(
+    write_custom_event(
         "refined_answer_improvement",
         RefinedAnswerImprovement(
             refined_answer_improvement=refined_answer_improvement,
         ),
+        writer,
     )

     now_end = datetime.now()
diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/create_refined_sub_questions.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/create_refined_sub_questions.py
index 199a3fd26..55df753b5 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/create_refined_sub_questions.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/create_refined_sub_questions.py
@@ -1,10 +1,10 @@
 from datetime import datetime
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.main.models import (
     FollowUpSubQuestion,
@@ -29,15 +29,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     format_entity_term_extraction,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.tools.models import ToolCallKickoff


 def create_refined_sub_questions(
-    state: MainState, config: RunnableConfig
+    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> FollowUpSubQuestionsUpdate:
     """ """
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    dispatch_custom_event(
+    write_custom_event(
         "start_refined_answer_creation",
         ToolCallKickoff(
             tool_name="agent_search_1",
@@ -46,6 +47,7 @@ def create_refined_sub_questions(
                 "answer": state.initial_answer,
             },
         ),
+        writer,
     )

     now_start = datetime.now()
@@ -90,7 +92,9 @@ def create_refined_sub_questions(

     # Grader
     model = agent_a_config.fast_llm
-    streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
+    streamed_tokens = dispatch_separated(
+        model.stream(msg), dispatch_subquestion(1, writer)
+    )

     response = merge_content(*streamed_tokens)
     if isinstance(response, str):
diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py
index 6389bad61..1be89ec02 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py
@@ -2,10 +2,10 @@ from datetime import datetime
 from typing import Any
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.main.models import (
     AgentRefinedMetrics,
@@ -44,6 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@@ -52,7 +53,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp


 def generate_refined_answer(
-    state: MainState, config: RunnableConfig
+    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> RefinedAnswerUpdate:
     now_start = datetime.now()

@@ -102,7 +103,7 @@ def generate_refined_answer(
         get_section_relevance=lambda: None,  # TODO: add relevance
         search_tool=agent_a_config.search_tool,
     ):
-        dispatch_custom_event(
+        write_custom_event(
             "tool_response",
             ExtendedToolResponse(
                 id=tool_response.id,
@@ -110,6 +111,7 @@ def generate_refined_answer(
                 level=1,
                 level_question_nr=0,  # 0, 0 is the base question
             ),
+            writer,
         )

     if len(initial_documents) > 0:
@@ -228,7 +230,7 @@ def generate_refined_answer(
                     )

                 start_stream_token = datetime.now()
-                dispatch_custom_event(
+                write_custom_event(
                     "refined_agent_answer",
                     AgentAnswerPiece(
                         answer_piece=content,
@@ -236,6 +238,7 @@ def generate_refined_answer(
                         level_question_nr=0,
                         answer_type="agent_level_answer",
                     ),
+                    writer,
                 )
                 end_stream_token = datetime.now()
                 dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
@@ -244,7 +247,7 @@ def generate_refined_answer(
     logger.info(
         f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
     )
-    dispatch_main_answer_stop_info(1)
+    dispatch_main_answer_stop_info(1, writer)

     response = merge_content(*streamed_tokens)
     answer = cast(str, response)
diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/operations.py b/backend/onyx/agents/agent_search/deep_search_a/main/operations.py
index ef3197e7b..c5b72ca2e 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/main/operations.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/main/operations.py
@@ -1,7 +1,7 @@
 import re
 from collections.abc import Callable

-from langchain_core.callbacks.manager import dispatch_custom_event
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
@@ -9,6 +9,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
 from onyx.agents.agent_search.shared_graph_utils.models import (
     QuestionAnswerResults,
 )
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import SubQuestionPiece
 from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
@@ -38,15 +39,18 @@ def remove_document_citations(text: str) -> str:
     return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)


-def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
+def dispatch_subquestion(
+    level: int, writer: StreamWriter
+) -> Callable[[str, int], None]:
     def _helper(sub_question_part: str, num: int) -> None:
-        dispatch_custom_event(
+        write_custom_event(
             "decomp_qs",
             SubQuestionPiece(
                 sub_question=sub_question_part,
                 level=level,
                 level_question_nr=num,
             ),
+            writer,
         )

     return _helper
diff --git a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/expand_queries.py
index 620267af5..7ce748c50 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/expand_queries.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/expand_queries.py
@@ -4,6 +4,7 @@ from typing import cast
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.operations import (
     dispatch_subquery,
@@ -26,7 +27,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id


 def expand_queries(
-    state: ExpandedRetrievalInput, config: RunnableConfig
+    state: ExpandedRetrievalInput,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
 ) -> QueryExpansionUpdate:
     # Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
     # When we are running this node on the original question, no question is explictly passed in.
@@ -53,7 +56,7 @@ def expand_queries(
     ]

     llm_response_list = dispatch_separated(
-        llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
+        llm.stream(prompt=msg), dispatch_subquery(level, question_nr, writer)
     )

     llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
diff --git a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/format_results.py
index 198c9b872..f153d9b23 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/format_results.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/nodes/format_results.py
@@ -1,7 +1,7 @@
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import (
     ExpandedRetrievalResult,
@@ -18,12 +18,15 @@ from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.states imp
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import ExtendedToolResponse
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses


 def format_results(
-    state: ExpandedRetrievalState, config: RunnableConfig
+    state: ExpandedRetrievalState,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
 ) -> ExpandedRetrievalUpdate:
     level, question_nr = parse_question_id(state.sub_question_id or "0_0")
     query_infos = [
@@ -55,7 +58,7 @@ def format_results(
         get_section_relevance=lambda: None,  # TODO: add relevance
         search_tool=agent_a_config.search_tool,
     ):
-        dispatch_custom_event(
+        write_custom_event(
             "tool_response",
             ExtendedToolResponse(
                 id=tool_response.id,
@@ -63,6 +66,7 @@ def format_results(
                 level=level,
                 level_question_nr=question_nr,
             ),
+            writer,
         )
     sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
         verified_documents=state.verified_documents,
diff --git a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py
index 03cddd05a..a6ac9bfa5 100644
--- a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py
+++ b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py
@@ -2,10 +2,11 @@ from collections import defaultdict
 from collections.abc import Callable

 import numpy as np
-from langchain_core.callbacks.manager import dispatch_custom_event
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import SubQueryPiece
 from onyx.context.search.models import InferenceSection
 from onyx.utils.logger import setup_logger
@@ -13,9 +14,11 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
+def dispatch_subquery(
+    level: int, question_nr: int, writer: StreamWriter
+) -> Callable[[str, int], None]:
     def helper(token: str, num: int) -> None:
-        dispatch_custom_event(
+        write_custom_event(
             "subqueries",
             SubQueryPiece(
                 sub_query=token,
@@ -23,6 +26,7 @@ def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None
                 level_question_nr=question_nr,
                 query_id=num,
             ),
+            writer,
         )

     return helper
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/basic_use_tool_response.py b/backend/onyx/agents/agent_search/orchestration/nodes/basic_use_tool_response.py
index 22f8ceb12..594274fab 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/basic_use_tool_response.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/basic_use_tool_response.py
@@ -2,6 +2,7 @@ from typing import cast

 from langchain_core.messages import AIMessageChunk
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
@@ -19,7 +20,9 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
+def basic_use_tool_response(
+    state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
+) -> BasicOutput:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     structured_response_format = agent_config.structured_response_format
     llm = agent_config.primary_llm
@@ -65,6 +68,7 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
     new_tool_call_chunk = process_llm_stream(
         stream,
         True,
+        writer,
         final_search_results=final_search_results,
         displayed_search_results=initial_search_results,
     )
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py b/backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py
index 221cfa0fa..821370cef 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/llm_tool_choice.py
@@ -3,6 +3,7 @@ from uuid import uuid4

 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -24,7 +25,11 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
+def llm_tool_choice(
+    state: ToolChoiceState,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
+) -> ToolChoiceUpdate:
     """
     This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
     The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
@@ -97,7 +102,9 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
     )

     tool_message = process_llm_stream(
-        stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
+        stream,
+        should_stream_answer and not agent_config.skip_gen_ai_answer_generation,
+        writer,
     )

     # If no tool calls are emitted by the LLM, we should not choose a tool
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py b/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py
index 5b35f0faa..b580633f6 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py
@@ -1,14 +1,15 @@
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.orchestration.states import ToolCallOutput
 from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AnswerPacket
 from onyx.tools.message import build_tool_message
 from onyx.tools.message import ToolCallSummary
@@ -19,11 +20,15 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-def emit_packet(packet: AnswerPacket) -> None:
-    dispatch_custom_event("basic_response", packet)
+def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
+    write_custom_event("basic_response", packet, writer)


-def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
+def tool_call(
+    state: ToolChoiceUpdate,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
+) -> ToolCallUpdate:
     """Calls the tool specified in the state and updates the state with the result"""

     cast(AgentSearchConfig, config["metadata"]["config"])
@@ -38,15 +43,15 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate
     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()

-    emit_packet(tool_kickoff)
+    emit_packet(tool_kickoff, writer)

     tool_responses = []
     for response in tool_runner.tool_responses():
         tool_responses.append(response)
-        emit_packet(response)
+        emit_packet(response, writer)

     tool_final_result = tool_runner.tool_final_result()
-    emit_packet(tool_final_result)
+    emit_packet(tool_final_result, writer)

     tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
     tool_call_summary = ToolCallSummary(
diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py
index ea644ae19..37a9f4e4c 100644
--- a/backend/onyx/agents/agent_search/run_graph.py
+++ b/backend/onyx/agents/agent_search/run_graph.py
@@ -4,6 +4,7 @@ from collections.abc import Iterable
 from datetime import datetime
 from typing import cast

+from langchain_core.runnables.schema import CustomStreamEvent
 from langchain_core.runnables.schema import StreamEvent
 from langgraph.graph.state import CompiledStateGraph

@@ -144,11 +145,8 @@ def manage_sync_streaming(
         stream_mode="custom",
         input=graph_input,
         config={"metadata": {"config": config, "thread_id": str(message_id)}},
-        # debug=True,
     ):
-        print(event)
-
-    return []
+        yield cast(CustomStreamEvent, event)


 def run_graph(
@@ -159,7 +157,7 @@ def run_graph(
     config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
     config.allow_refinement = ALLOW_REFINEMENT

-    for event in _manage_async_event_streaming(
+    for event in manage_sync_streaming(
         compiled_graph=compiled_graph, config=config, graph_input=input
     ):
         if not (parsed_object := _parse_agent_event(event)):
diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
index b14405c87..a65c1cfe2 100644
--- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
+++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py
@@ -8,11 +8,13 @@ from datetime import datetime
 from datetime import timedelta
 from typing import Any
 from typing import cast
+from typing import Literal
+from typing import TypedDict
 from uuid import UUID

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
+from langgraph.types import StreamWriter
 from sqlalchemy.orm import Session

 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -30,6 +32,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
     HISTORY_CONTEXT_SUMMARY_PROMPT,
 )
+from onyx.chat.models import AnswerPacket
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import DocumentPruningConfig
@@ -57,6 +60,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
 from onyx.tools.tool_implementations.search.search_tool import SearchTool

+
 BaseMessage_Content = str | list[str | dict[str, Any]]

@@ -314,13 +318,13 @@ def dispatch_separated(
     return streamed_tokens


-def dispatch_main_answer_stop_info(level: int) -> None:
+def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,
         stream_type="main_answer",
         level=level,
     )
-    dispatch_custom_event("stream_finished", stop_event)
+    write_custom_event("stream_finished", stop_event, writer)


 def get_today_prompt() -> str:
@@ -368,3 +372,22 @@ def summarize_history(
         history_context_response_str = ""

     return history_context_response_str
+
+
+# taken from langchain_core.runnables.schema
+# we don't use the one from their library because
+# it includes ids they generate
+class CustomStreamEvent(TypedDict):
+    # Overwrite the event field to be more specific.
+    event: Literal["on_custom_event"]  # type: ignore[misc]
+    """The event type."""
+    name: str
+    """User defined name for the event."""
+    data: Any
+    """The data associated with the event. Free form and can be anything."""
+
+
+def write_custom_event(
+    name: str, event: AnswerPacket, stream_writer: StreamWriter
+) -> None:
+    stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
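
A note on the mechanism this diff adopts (the sketch below is illustrative, not part of the PR): LangGraph injects a `StreamWriter` into any node whose signature declares a `writer: StreamWriter` parameter, and every payload handed to that writer is yielded from `graph.stream(..., stream_mode="custom")`, the mode that `manage_sync_streaming` consumes above. The `writer: StreamWriter = lambda _: None` default keeps each node callable outside a compiled graph (for example in unit tests), where nothing is injected. The `State` shape, node name, and token payloads here are hypothetical stand-ins, not Onyx code:

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import StreamWriter


class State(TypedDict):
    question: str


def answer_node(state: State, writer: StreamWriter = lambda _: None) -> State:
    # LangGraph sees the StreamWriter annotation and injects the real writer;
    # the no-op default only applies when the node is invoked directly.
    for token in ("Hello", ", ", "world"):
        # Mimic the CustomStreamEvent shape that write_custom_event constructs.
        writer({"event": "on_custom_event", "name": "basic_response", "data": token})
    return state


builder = StateGraph(State)
builder.add_node("answer", answer_node)
builder.add_edge(START, "answer")
builder.add_edge("answer", END)
graph = builder.compile()

# stream_mode="custom" yields exactly the objects passed to the writer, which is
# why run_graph.py can cast each streamed event to CustomStreamEvent.
for event in graph.stream({"question": "hi"}, stream_mode="custom"):
    print(event)
```

Because the writer defaults to a no-op, a node can also be exercised without a graph, e.g. `answer_node({"question": "hi"}, writer=collected.append)` with `collected: list = []` to capture the emitted events in a test.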