refactor: stream_llm_answer (#4772)

* refactor: stream_llm_answer

* fix: lambda

* fix: mypy, docstring
Rei Meguro, 2025-05-26 15:29:33 -07:00, committed by GitHub
parent 23ff3476bc | commit 0b01d7f848
5 changed files with 127 additions and 143 deletions
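
The four call sites below previously each open-coded the same token-streaming loop: stream from the LLM, reject non-string chunks, dispatch an AgentAnswerPiece as a custom event, and record per-token dispatch timings. This commit extracts that loop into one shared helper. Its contract, assembled from the new module at the bottom of this diff (llm, msg, and writer stand in for names from each node's scope):

tokens, timings = stream_llm_answer(
    llm=llm,                                 # any onyx LLM instance
    prompt=msg,                              # a LanguageModelInput prompt
    event_name="initial_agent_answer",       # custom event stream to write to
    writer=writer,                           # langgraph StreamWriter
    agent_answer_level=0,
    agent_answer_question_num=0,
    agent_answer_type="agent_level_answer",  # or "agent_sub_answer"
    timeout_override=30,                     # passed through to llm.stream
    max_tokens=None,                         # None = no explicit output cap
)
# tokens: streamed string pieces; timings: per-token dispatch durations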

File 1 of 5: consolidate_research

@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import cast
 from langchain_core.messages import HumanMessage
@@ -12,6 +11,7 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
@@ -113,42 +113,20 @@ def consolidate_research(
         )
     ]
-    dispatch_timings: list[float] = []
-    primary_model = graph_config.tooling.primary_llm
-
-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "initial_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=0,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
         _ = run_with_timeout(
             60,
-            stream_initial_answer,
+            lambda: stream_llm_answer(
+                llm=graph_config.tooling.primary_llm,
+                prompt=msg,
+                event_name="initial_agent_answer",
+                writer=writer,
+                agent_answer_level=0,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=30,
+                max_tokens=None,
+            ),
         )
     except Exception as e:
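
The "fix: lambda" entry in the commit message corresponds to the call shape above: run_with_timeout takes a zero-argument callable, so the keyword-argument call to stream_llm_answer must be deferred behind a lambda rather than invoked on the spot. Onyx's actual run_with_timeout is not shown in this diff; a minimal sketch of the assumed semantics:

from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import TypeVar

R = TypeVar("R")

def run_with_timeout(timeout: float, func: Callable[[], R]) -> R:
    # Run func in a worker thread; raise a timeout error if it overruns.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        return executor.submit(func).result(timeout=timeout)
    finally:
        # Don't block waiting for a still-running task on the timeout path.
        executor.shutdown(wait=False)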

File 2 of 5: generate_sub_answer

@@ -30,6 +30,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     LLM_ANSWER_ERROR_MESSAGE,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
@@ -112,44 +113,23 @@ def generate_sub_answer(
         config=fast_llm.config,
     )
-    dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None
     response: list[str] = []
-
-    def stream_sub_answer() -> list[str]:
-        for message in fast_llm.stream(
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
-            max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "sub_answers",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=level,
-                    level_question_num=question_num,
-                    answer_type="agent_sub_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
-        response = run_with_timeout(
+        response, _ = run_with_timeout(
             AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
-            stream_sub_answer,
+            lambda: stream_llm_answer(
+                llm=fast_llm,
+                prompt=msg,
+                event_name="sub_answers",
+                writer=writer,
+                agent_answer_level=level,
+                agent_answer_question_num=question_num,
+                agent_answer_type="agent_sub_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
+                max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
+            ),
         )
     except (LLMTimeoutError, TimeoutError):
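
Note what each call site keeps from the helper's (tokens, timings) return value: consolidate_research above discards both (_ =), generate_sub_answer keeps only the tokens (response, _ =), and the two answer-generation nodes below keep both. The code that consumes response falls outside the hunk; a sketch of the assumed reassembly:

# response holds the streamed string pieces in arrival order.
response: list[str] = ["Paris ", "is ", "the ", "capital."]
answer = "".join(response)  # "Paris is the capital."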

File 3 of 5: generate_initial_answer

@@ -37,6 +37,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     AgentLLMErrorType,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
@@ -275,46 +276,24 @@ def generate_initial_answer(
     agent_error: AgentErrorLog | None = None
-
-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        for message in model.stream(
-            msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
-            max_tokens=(
-                AGENT_MAX_TOKENS_ANSWER_GENERATION
-                if _should_restrict_tokens(model.config)
-                else None
-            ),
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "initial_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=0,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
-        streamed_tokens = run_with_timeout(
+        streamed_tokens, dispatch_timings = run_with_timeout(
             AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
-            stream_initial_answer,
+            lambda: stream_llm_answer(
+                llm=model,
+                prompt=msg,
+                event_name="initial_agent_answer",
+                writer=writer,
+                agent_answer_level=0,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
+                max_tokens=(
+                    AGENT_MAX_TOKENS_ANSWER_GENERATION
+                    if _should_restrict_tokens(model.config)
+                    else None
+                ),
+            ),
         )
     except (LLMTimeoutError, TimeoutError):
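
This node, like the refined-answer node below, also keeps dispatch_timings. Each entry is the .microseconds component of one token's dispatch timedelta, i.e. only the sub-second part, not the total elapsed time ((end - start).total_seconds() would give that); the refactor preserves this original behavior. A sketch of a plausible consumer, assumed since it lies outside the hunk:

# One value per streamed token, from (end_stream_token - start_stream_token).microseconds.
dispatch_timings: list[float] = [120.0, 95.0, 143.0]
if dispatch_timings:
    avg_dispatch_us = sum(dispatch_timings) / len(dispatch_timings)
    print(f"average token dispatch time: {avg_dispatch_us:.1f} microseconds")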

File 4 of 5: generate_validate_refined_answer

@@ -40,6 +40,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     AgentLLMErrorType,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
@@ -63,7 +64,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     remove_document_citations,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
-from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
@@ -301,45 +301,24 @@ def generate_validate_refined_answer(
     dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None
-
-    def stream_refined_answer() -> list[str]:
-        for message in model.stream(
-            msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
-            max_tokens=(
-                AGENT_MAX_TOKENS_ANSWER_GENERATION
-                if _should_restrict_tokens(model.config)
-                else None
-            ),
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "refined_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=1,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            streamed_tokens.append(content)
-        return streamed_tokens
-
     try:
-        streamed_tokens = run_with_timeout(
+        streamed_tokens, dispatch_timings = run_with_timeout(
             AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
-            stream_refined_answer,
+            lambda: stream_llm_answer(
+                llm=model,
+                prompt=msg,
+                event_name="refined_agent_answer",
+                writer=writer,
+                agent_answer_level=1,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
+                max_tokens=(
+                    AGENT_MAX_TOKENS_ANSWER_GENERATION
+                    if _should_restrict_tokens(model.config)
+                    else None
+                ),
+            ),
         )
     except (LLMTimeoutError, TimeoutError):

File 5 of 5: onyx/agents/agent_search/shared_graph_utils/llm.py (new file)

@@ -0,0 +1,68 @@
+from datetime import datetime
+from typing import Literal
+
+from langchain.schema.language_model import LanguageModelInput
+from langgraph.types import StreamWriter
+
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
+from onyx.chat.models import AgentAnswerPiece
+from onyx.llm.interfaces import LLM
+
+
+def stream_llm_answer(
+    llm: LLM,
+    prompt: LanguageModelInput,
+    event_name: str,
+    writer: StreamWriter,
+    agent_answer_level: int,
+    agent_answer_question_num: int,
+    agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
+    timeout_override: int | None = None,
+    max_tokens: int | None = None,
+) -> tuple[list[str], list[float]]:
+    """Stream an answer from the LLM, dispatching each piece as a custom event.
+
+    Args:
+        llm: The LLM to use.
+        prompt: The prompt to use.
+        event_name: The name of the event to write each answer piece to.
+        writer: The stream writer to write the events to.
+        agent_answer_level: The level of the agent answer.
+        agent_answer_question_num: The question number within the level.
+        agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
+        timeout_override: The LLM timeout to use.
+        max_tokens: The LLM max tokens to use.
+
+    Returns:
+        A tuple of the streamed tokens and the per-token dispatch timings (microseconds).
+    """
+    response: list[str] = []
+    dispatch_timings: list[float] = []
+
+    for message in llm.stream(
+        prompt, timeout_override=timeout_override, max_tokens=max_tokens
+    ):
+        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+        content = message.content
+        if not isinstance(content, str):
+            raise ValueError(
+                f"Expected content to be a string, but got {type(content)}"
+            )
+
+        start_stream_token = datetime.now()
+        write_custom_event(
+            event_name,
+            AgentAnswerPiece(
+                answer_piece=content,
+                level=agent_answer_level,
+                level_question_num=agent_answer_question_num,
+                answer_type=agent_answer_type,
+            ),
+            writer,
+        )
+        end_stream_token = datetime.now()
+
+        dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
+        response.append(content)
+
+    return response, dispatch_timings
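
The "fix: mypy" entry in the commit message plausibly relates to the Literal type on agent_answer_type: a value typed as plain str does not satisfy a Literal parameter, while string literals at the call sites narrow to it automatically. A self-contained illustration of that typing behavior (dispatch is a hypothetical stand-in for the AgentAnswerPiece construction above):

from typing import Literal

AnswerType = Literal["agent_level_answer", "agent_sub_answer"]

def dispatch(answer_type: AnswerType) -> None:
    # Stand-in for AgentAnswerPiece(answer_type=...) in the real code.
    print(answer_type)

dispatch("agent_sub_answer")    # OK: the literal narrows to AnswerType
some_str: str = "agent_sub_answer"
dispatch(some_str)              # mypy error: "str" is not a valid AnswerType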