sync streaming impl

Evan Lohn 2025-01-30 16:58:46 -08:00
parent 4a0b2a6c09
commit a340529de3
16 changed files with 139 additions and 60 deletions

View File

@ -1,10 +1,11 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
@ -20,6 +21,7 @@ logger = setup_logger()
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
@ -52,9 +54,10 @@ def process_llm_stream(
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
dispatch_custom_event(
write_custom_event(
"basic_response",
response_part,
writer,
)
logger.info(f"Full answer: {full_answer}")

View File

@ -2,9 +2,9 @@ from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
@ -22,6 +22,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@ -32,7 +33,9 @@ logger = setup_logger()
def generate_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
state: AnswerQuestionState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> QAGenerationUpdate:
now_start = datetime.now()
logger.info(f"--------{now_start}--------START ANSWER GENERATION---")
@ -48,7 +51,7 @@ def generate_sub_answer(
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
dispatch_custom_event(
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
@ -56,6 +59,7 @@ def generate_sub_answer(
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
writer,
)
else:
logger.debug(f"Number of verified retrieval docs: {len(context_docs)}")
@ -81,7 +85,7 @@ def generate_sub_answer(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
dispatch_custom_event(
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
@ -89,6 +93,7 @@ def generate_sub_answer(
level_question_nr=question_nr,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
@ -112,7 +117,7 @@ def generate_sub_answer(
level=level,
level_question_nr=question_nr,
)
dispatch_custom_event("stream_finished", stop_event)
write_custom_event("stream_finished", stop_event, writer)
now_end = datetime.now()

View File

@ -2,10 +2,10 @@ from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.initial.generate_initial_answer.states import (
SearchSQState,
@ -45,6 +45,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@ -54,7 +55,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp
def generate_initial_answer(
state: SearchSQState, config: RunnableConfig
state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> InitialAnswerUpdate:
now_start = datetime.now()
@ -98,7 +99,7 @@ def generate_initial_answer(
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
@ -106,10 +107,11 @@ def generate_initial_answer(
level=0,
level_question_nr=0, # 0, 0 is the base question
),
writer,
)
if len(relevant_docs) == 0:
dispatch_custom_event(
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
@ -117,8 +119,9 @@ def generate_initial_answer(
level_question_nr=0,
answer_type="agent_level_answer",
),
writer,
)
dispatch_main_answer_stop_info(0)
dispatch_main_answer_stop_info(0, writer)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
@ -197,7 +200,7 @@ def generate_initial_answer(
)
start_stream_token = datetime.now()
dispatch_custom_event(
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
@ -205,6 +208,7 @@ def generate_initial_answer(
level_question_nr=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
@ -216,7 +220,7 @@ def generate_initial_answer(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(0)
dispatch_main_answer_stop_info(0, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)

View File

@ -1,10 +1,10 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.initial.generate_initial_answer.states import (
SearchSQState,
@ -28,6 +28,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionPiece
@ -35,7 +36,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
def decompose_orig_question(
state: SearchSQState, config: RunnableConfig
state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BaseDecompUpdate:
now_start = datetime.now()
@ -90,23 +91,26 @@ def decompose_orig_question(
msg = [HumanMessage(content=decomposition_prompt)]
# Send the initial question as a subquestion with number 0
dispatch_custom_event(
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_nr=0,
),
writer,
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(0, writer)
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="sub_questions",
level=0,
)
dispatch_custom_event("stream_finished", stop_event)
write_custom_event("stream_finished", stop_event, writer)
decomposition_response = merge_content(*streamed_tokens)

View File

@ -1,19 +1,22 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import AnswerComparison
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
def compare_answers(state: MainState, config: RunnableConfig) -> AnswerComparison:
def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnswerComparison:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
@ -39,11 +42,12 @@ def compare_answers(state: MainState, config: RunnableConfig) -> AnswerCompariso
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
dispatch_custom_event(
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
writer,
)
now_end = datetime.now()

View File

@ -1,10 +1,10 @@
from datetime import datetime
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.main.models import (
FollowUpSubQuestion,
@ -29,15 +29,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.tools.models import ToolCallKickoff
def create_refined_sub_questions(
state: MainState, config: RunnableConfig
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> FollowUpSubQuestionsUpdate:
""" """
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
dispatch_custom_event(
write_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
@ -46,6 +47,7 @@ def create_refined_sub_questions(
"answer": state.initial_answer,
},
),
writer,
)
now_start = datetime.now()
@ -90,7 +92,9 @@ def create_refined_sub_questions(
# Grader
model = agent_a_config.fast_llm
streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(1, writer)
)
response = merge_content(*streamed_tokens)
if isinstance(response, str):

View File

@ -2,10 +2,10 @@ from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.main.models import (
AgentRefinedMetrics,
@ -44,6 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@ -52,7 +53,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp
def generate_refined_answer(
state: MainState, config: RunnableConfig
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedAnswerUpdate:
now_start = datetime.now()
@ -102,7 +103,7 @@ def generate_refined_answer(
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
@ -110,6 +111,7 @@ def generate_refined_answer(
level=1,
level_question_nr=0, # 0, 0 is the base question
),
writer,
)
if len(initial_documents) > 0:
@ -228,7 +230,7 @@ def generate_refined_answer(
)
start_stream_token = datetime.now()
dispatch_custom_event(
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
@ -236,6 +238,7 @@ def generate_refined_answer(
level_question_nr=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
@ -244,7 +247,7 @@ def generate_refined_answer(
logger.info(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1)
dispatch_main_answer_stop_info(1, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)

View File

@ -1,7 +1,7 @@
import re
from collections.abc import Callable
from langchain_core.callbacks.manager import dispatch_custom_event
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
@ -9,6 +9,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.models import (
QuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
@ -38,15 +39,18 @@ def remove_document_citations(text: str) -> str:
return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)
def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
def dispatch_subquestion(
level: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def _helper(sub_question_part: str, num: int) -> None:
dispatch_custom_event(
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=sub_question_part,
level=level,
level_question_nr=num,
),
writer,
)
return _helper
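Binding the writer when the closure is created keeps the per-token callback signature (str, int) unchanged for dispatch_separated. An illustrative use, with a made-up sub-question string (level 0 is the original question, per the call sites above):

emit = dispatch_subquestion(0, writer)
emit("What are the key risks?", 0)  # streams a SubQuestionPiece through the writer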

View File

@ -4,6 +4,7 @@ from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.operations import (
dispatch_subquery,
@ -26,7 +27,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
def expand_queries(
state: ExpandedRetrievalInput, config: RunnableConfig
state: ExpandedRetrievalInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> QueryExpansionUpdate:
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explicitly passed in.
@ -53,7 +56,7 @@ def expand_queries(
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
llm.stream(prompt=msg), dispatch_subquery(level, question_nr, writer)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

View File

@ -1,7 +1,7 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import (
ExpandedRetrievalResult,
@ -18,12 +18,15 @@ from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.states imp
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def format_results(
state: ExpandedRetrievalState, config: RunnableConfig
state: ExpandedRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate:
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_infos = [
@ -55,7 +58,7 @@ def format_results(
get_section_relevance=lambda: None, # TODO: add relevance
search_tool=agent_a_config.search_tool,
):
dispatch_custom_event(
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
@ -63,6 +66,7 @@ def format_results(
level=level,
level_question_nr=question_nr,
),
writer,
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,

View File

@ -2,10 +2,11 @@ from collections import defaultdict
from collections.abc import Callable
import numpy as np
from langchain_core.callbacks.manager import dispatch_custom_event
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
@ -13,9 +14,11 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
def dispatch_subquery(
level: int, question_nr: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
dispatch_custom_event(
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=token,
@ -23,6 +26,7 @@ def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None
level_question_nr=question_nr,
query_id=num,
),
writer,
)
return helper

View File

@ -2,6 +2,7 @@ from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
@ -19,7 +20,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
structured_response_format = agent_config.structured_response_format
llm = agent_config.primary_llm
@ -65,6 +68,7 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
new_tool_call_chunk = process_llm_stream(
stream,
True,
writer,
final_search_results=final_search_results,
displayed_search_results=initial_search_results,
)

View File

@ -3,6 +3,7 @@ from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
@ -24,7 +25,11 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
def llm_tool_choice(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
the node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
@ -97,7 +102,9 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
)
tool_message = process_llm_stream(
stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
stream,
should_stream_answer and not agent_config.skip_gen_ai_answer_generation,
writer,
)
# If no tool calls are emitted by the LLM, we should not choose a tool

View File

@ -1,14 +1,15 @@
from typing import cast
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
@ -19,11 +20,15 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def emit_packet(packet: AnswerPacket) -> None:
dispatch_custom_event("basic_response", packet)
def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
write_custom_event("basic_response", packet, writer)
def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
def tool_call(
state: ToolChoiceUpdate,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
cast(AgentSearchConfig, config["metadata"]["config"])
@ -38,15 +43,15 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff)
emit_packet(tool_kickoff, writer)
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
emit_packet(response)
emit_packet(response, writer)
tool_final_result = tool_runner.tool_final_result()
emit_packet(tool_final_result)
emit_packet(tool_final_result, writer)
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
tool_call_summary = ToolCallSummary(

View File

@ -4,6 +4,7 @@ from collections.abc import Iterable
from datetime import datetime
from typing import cast
from langchain_core.runnables.schema import CustomStreamEvent
from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph
@ -144,11 +145,8 @@ def manage_sync_streaming(
stream_mode="custom",
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
):
print(event)
return []
yield cast(CustomStreamEvent, event)
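The debug print-and-return is replaced with a real generator: with stream_mode="custom", CompiledStateGraph.stream yields exactly the objects each node handed to its writer, so the cast to CustomStreamEvent matches the shape produced by write_custom_event (defined at the bottom of this commit). A hedged consumer sketch, where handle_packet is illustrative:

for event in manage_sync_streaming(compiled_graph, config, graph_input):
    # event looks like {"event": "on_custom_event", "name": ..., "data": <AnswerPacket>}
    handle_packet(event["name"], event["data"])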
def run_graph(
@ -159,7 +157,7 @@ def run_graph(
config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
config.allow_refinement = ALLOW_REFINEMENT
for event in _manage_async_event_streaming(
for event in manage_sync_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
):
if not (parsed_object := _parse_agent_event(event)):

View File

@ -8,11 +8,13 @@ from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from typing import Literal
from typing import TypedDict
from uuid import UUID
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
@ -30,6 +32,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
@ -57,6 +60,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
BaseMessage_Content = str | list[str | dict[str, Any]]
@ -314,13 +318,13 @@ def dispatch_separated(
return streamed_tokens
def dispatch_main_answer_stop_info(level: int) -> None:
def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type="main_answer",
level=level,
)
dispatch_custom_event("stream_finished", stop_event)
write_custom_event("stream_finished", stop_event, writer)
def get_today_prompt() -> str:
@ -368,3 +372,22 @@ def summarize_history(
history_context_response_str = ""
return history_context_response_str
# taken from langchain_core.runnables.schema
# we don't use the one from their library because
# it includes ids they generate
class CustomStreamEvent(TypedDict):
# Overwrite the event field to be more specific.
event: Literal["on_custom_event"] # type: ignore[misc]
"""The event type."""
name: str
"""User defined name for the event."""
data: Any
"""The data associated with the event. Free form and can be anything."""
def write_custom_event(
name: str, event: AnswerPacket, stream_writer: StreamWriter
) -> None:
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
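Together with manage_sync_streaming above, this closes the loop: nodes wrap each AnswerPacket in the hand-rolled CustomStreamEvent (avoiding the run ids langchain's own schema generates) and push it through the injected writer, and the synchronous stream surfaces it unchanged. A usage sketch inside a hypothetical node, with field values that are illustrative only:

def some_node(state: dict, writer: StreamWriter = lambda _: None) -> dict:
    piece = AgentAnswerPiece(
        answer_piece="hello",  # illustrative token
        level=0,
        level_question_nr=0,
        answer_type="agent_level_answer",
    )
    write_custom_event("initial_agent_answer", piece, writer)
    return state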