From 0b01d7f848fa20c8b6543725a789c1abdb9b3be7 Mon Sep 17 00:00:00 2001
From: Rei Meguro <36625832+Orbital-Web@users.noreply.github.com>
Date: Mon, 26 May 2025 15:29:33 -0700
Subject: [PATCH] refactor: stream_llm_answer (#4772)

* refactor: stream_llm_answer

* fix: lambda

* fix: mypy, docstring
---
 .../nodes/a5_consolidate_research.py          | 46 ++++---------
 .../nodes/generate_sub_answer.py              | 46 ++++---------
 .../nodes/generate_initial_answer.py          | 55 +++++----------
 .../nodes/generate_validate_refined_answer.py | 55 +++++----------
 .../agent_search/shared_graph_utils/llm.py    | 68 +++++++++++++++++++
 5 files changed, 127 insertions(+), 143 deletions(-)
 create mode 100644 backend/onyx/agents/agent_search/shared_graph_utils/llm.py

diff --git a/backend/onyx/agents/agent_search/dc_search_analysis/nodes/a5_consolidate_research.py b/backend/onyx/agents/agent_search/dc_search_analysis/nodes/a5_consolidate_research.py
index 1ac89ab0e44..6b92af02777 100644
--- a/backend/onyx/agents/agent_search/dc_search_analysis/nodes/a5_consolidate_research.py
+++ b/backend/onyx/agents/agent_search/dc_search_analysis/nodes/a5_consolidate_research.py
@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import cast
 
 from langchain_core.messages import HumanMessage
@@ -12,6 +11,7 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
@@ -113,42 +113,20 @@ def consolidate_research(
         )
     ]
 
-    dispatch_timings: list[float] = []
-
-    primary_model = graph_config.tooling.primary_llm
-
-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-
-            write_custom_event(
-                "initial_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=0,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
         _ = run_with_timeout(
             60,
-            stream_initial_answer,
+            lambda: stream_llm_answer(
+                llm=graph_config.tooling.primary_llm,
+                prompt=msg,
+                event_name="initial_agent_answer",
+                writer=writer,
+                agent_answer_level=0,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=30,
+                max_tokens=None,
+            ),
         )
 
     except Exception as e:
diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
index 3ba16972a42..5e33d004232 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_individual_sub_answer/nodes/generate_sub_answer.py
@@ -30,6 +30,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     LLM_ANSWER_ERROR_MESSAGE,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
@@ -112,44 +113,23 @@ def generate_sub_answer(
         config=fast_llm.config,
     )
 
-    dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None
     response: list[str] = []
 
-    def stream_sub_answer() -> list[str]:
-        for message in fast_llm.stream(
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
-            max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "sub_answers",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=level,
-                    level_question_num=question_num,
-                    answer_type="agent_sub_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
-        response = run_with_timeout(
+        response, _ = run_with_timeout(
             AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
-            stream_sub_answer,
+            lambda: stream_llm_answer(
+                llm=fast_llm,
+                prompt=msg,
+                event_name="sub_answers",
+                writer=writer,
+                agent_answer_level=level,
+                agent_answer_question_num=question_num,
+                agent_answer_type="agent_sub_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
+                max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
+            ),
         )
 
     except (LLMTimeoutError, TimeoutError):
diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
index 03c33592586..90a84a80737 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
@@ -37,6 +37,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     AgentLLMErrorType,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
@@ -275,46 +276,24 @@ def generate_initial_answer(
 
     agent_error: AgentErrorLog | None = None
 
-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        for message in model.stream(
-            msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
-            max_tokens=(
-                AGENT_MAX_TOKENS_ANSWER_GENERATION
-                if _should_restrict_tokens(model.config)
-                else None
-            ),
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-
-            write_custom_event(
-                "initial_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=0,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
-        streamed_tokens = run_with_timeout(
+        streamed_tokens, dispatch_timings = run_with_timeout(
             AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
-            stream_initial_answer,
+            lambda: stream_llm_answer(
+                llm=model,
+                prompt=msg,
+                event_name="initial_agent_answer",
+                writer=writer,
+                agent_answer_level=0,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
+                max_tokens=(
+                    AGENT_MAX_TOKENS_ANSWER_GENERATION
+                    if _should_restrict_tokens(model.config)
+                    else None
+                ),
+            ),
         )
 
     except (LLMTimeoutError, TimeoutError):
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
index b2a1736211f..32f4d6ea693 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
@@ -40,6 +40,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     AgentLLMErrorType,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
@@ -63,7 +64,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     remove_document_citations,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
-from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
@@ -301,45 +301,24 @@ def generate_validate_refined_answer(
     dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None
 
-    def stream_refined_answer() -> list[str]:
-        for message in model.stream(
-            msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
-            max_tokens=(
-                AGENT_MAX_TOKENS_ANSWER_GENERATION
-                if _should_restrict_tokens(model.config)
-                else None
-            ),
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "refined_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=1,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            streamed_tokens.append(content)
-        return streamed_tokens
-
     try:
-        streamed_tokens = run_with_timeout(
+        streamed_tokens, dispatch_timings = run_with_timeout(
             AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
-            stream_refined_answer,
+            lambda: stream_llm_answer(
+                llm=model,
+                prompt=msg,
+                event_name="refined_agent_answer",
+                writer=writer,
+                agent_answer_level=1,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
+                max_tokens=(
+                    AGENT_MAX_TOKENS_ANSWER_GENERATION
+                    if _should_restrict_tokens(model.config)
+                    else None
+                ),
+            ),
         )
 
     except (LLMTimeoutError, TimeoutError):
diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/llm.py b/backend/onyx/agents/agent_search/shared_graph_utils/llm.py
new file mode 100644
index 00000000000..e11fb024a48
--- /dev/null
+++ b/backend/onyx/agents/agent_search/shared_graph_utils/llm.py
@@ -0,0 +1,68 @@
+from datetime import datetime
+from typing import Literal
+
+from langchain.schema.language_model import LanguageModelInput
+from langgraph.types import StreamWriter
+
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
+from onyx.chat.models import AgentAnswerPiece
+from onyx.llm.interfaces import LLM
+
+
+def stream_llm_answer(
+    llm: LLM,
+    prompt: LanguageModelInput,
+    event_name: str,
+    writer: StreamWriter,
+    agent_answer_level: int,
+    agent_answer_question_num: int,
+    agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
+    timeout_override: int | None = None,
+    max_tokens: int | None = None,
+) -> tuple[list[str], list[float]]:
+    """Stream the initial answer from the LLM.
+
+    Args:
+        llm: The LLM to use.
+        prompt: The prompt to use.
+        event_name: The name of the event to write.
+        writer: The writer to write to.
+        agent_answer_level: The level of the agent answer.
+        agent_answer_question_num: The question number within the level.
+        agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
+        timeout_override: The LLM timeout to use.
+        max_tokens: The LLM max tokens to use.
+
+    Returns:
+        A tuple of the response and the dispatch timings.
+    """
+    response: list[str] = []
+    dispatch_timings: list[float] = []
+
+    for message in llm.stream(
+        prompt, timeout_override=timeout_override, max_tokens=max_tokens
+    ):
+        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+        content = message.content
+        if not isinstance(content, str):
+            raise ValueError(
+                f"Expected content to be a string, but got {type(content)}"
+            )
+
+        start_stream_token = datetime.now()
+        write_custom_event(
+            event_name,
+            AgentAnswerPiece(
+                answer_piece=content,
+                level=agent_answer_level,
+                level_question_num=agent_answer_question_num,
+                answer_type=agent_answer_type,
+            ),
+            writer,
+        )
+        end_stream_token = datetime.now()
+
+        dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
+        response.append(content)
+
+    return response, dispatch_timings
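
Not part of the patch: a minimal usage sketch of the new helper. The calling pattern introduced here wraps stream_llm_answer in a lambda so run_with_timeout can bound the entire streaming loop, while timeout_override only bounds the underlying LLM request. The sketch mirrors the consolidate_research call site; the wrapper name stream_consolidated_answer and the run_with_timeout import path are assumptions for illustration, and graph_config, msg, and writer stand in for the values already present in each node.

# Illustrative sketch only -- mirrors the refactored consolidate_research call site above.
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter

from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.utils.threadpool_concurrency import run_with_timeout  # assumed import path


def stream_consolidated_answer(  # hypothetical wrapper name
    graph_config: GraphConfig, msg: list[HumanMessage], writer: StreamWriter
) -> tuple[list[str], list[float]]:
    # run_with_timeout bounds the whole streaming loop (60s here); the
    # timeout_override kwarg only bounds the underlying LLM request (30s).
    return run_with_timeout(
        60,
        lambda: stream_llm_answer(
            llm=graph_config.tooling.primary_llm,
            prompt=msg,
            event_name="initial_agent_answer",
            writer=writer,
            agent_answer_level=0,
            agent_answer_question_num=0,
            agent_answer_type="agent_level_answer",
            timeout_override=30,
            max_tokens=None,
        ),
    )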