Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-12 14:12:53 +02:00)
refactor: stream_llm_answer (#4772)
* refactor: stream_llm_answer
* fix: lambda
* fix: mypy, docstring
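The diffs below replace four near-identical inline streaming closures (in consolidate_research, generate_sub_answer, generate_initial_answer, and generate_validate_refined_answer) with one shared helper, backend/onyx/agents/agent_search/shared_graph_utils/llm.py::stream_llm_answer. A schematic sketch of the new call-site pattern follows; it is not an exact excerpt from the patch, the names llm, msg, and writer stand in for whatever each node already has in scope, and the timeout constants vary per call site:

    try:
        # stream_llm_answer returns (streamed_tokens, dispatch_timings);
        # run_with_timeout bounds the whole streaming call for this node.
        streamed_tokens, dispatch_timings = run_with_timeout(
            AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
            lambda: stream_llm_answer(
                llm=llm,
                prompt=msg,
                event_name="initial_agent_answer",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
                max_tokens=None,
            ),
        )
    except (LLMTimeoutError, TimeoutError):
        ...  # each call site keeps its existing error handling
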
@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import cast

 from langchain_core.messages import HumanMessage
@@ -12,6 +11,7 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
@@ -113,42 +113,20 @@ def consolidate_research(
         )
     ]

-    dispatch_timings: list[float] = []
-
-    primary_model = graph_config.tooling.primary_llm
-
-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-
-            write_custom_event(
-                "initial_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=0,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response
-
     try:
         _ = run_with_timeout(
             60,
-            stream_initial_answer,
+            lambda: stream_llm_answer(
+                llm=graph_config.tooling.primary_llm,
+                prompt=msg,
+                event_name="initial_agent_answer",
+                writer=writer,
+                agent_answer_level=0,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=30,
+                max_tokens=None,
+            ),
         )

     except Exception as e:
@@ -30,6 +30,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     LLM_ANSWER_ERROR_MESSAGE,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
@@ -112,44 +113,23 @@ def generate_sub_answer(
         config=fast_llm.config,
     )

-    dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None
     response: list[str] = []
-
-    def stream_sub_answer() -> list[str]:
-        for message in fast_llm.stream(
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
-            max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "sub_answers",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=level,
-                    level_question_num=question_num,
-                    answer_type="agent_sub_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response

     try:
-        response = run_with_timeout(
+        response, _ = run_with_timeout(
             AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
-            stream_sub_answer,
+            lambda: stream_llm_answer(
+                llm=fast_llm,
+                prompt=msg,
+                event_name="sub_answers",
+                writer=writer,
+                agent_answer_level=level,
+                agent_answer_question_num=question_num,
+                agent_answer_type="agent_sub_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
+                max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
+            ),
         )

     except (LLMTimeoutError, TimeoutError):
@@ -37,6 +37,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     AgentLLMErrorType,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
@@ -275,46 +276,24 @@ def generate_initial_answer(

     agent_error: AgentErrorLog | None = None
-
-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        for message in model.stream(
-            msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
-            max_tokens=(
-                AGENT_MAX_TOKENS_ANSWER_GENERATION
-                if _should_restrict_tokens(model.config)
-                else None
-            ),
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-            start_stream_token = datetime.now()
-
-            write_custom_event(
-                "initial_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=0,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            response.append(content)
-        return response

     try:
-        streamed_tokens = run_with_timeout(
+        streamed_tokens, dispatch_timings = run_with_timeout(
             AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
-            stream_initial_answer,
+            lambda: stream_llm_answer(
+                llm=model,
+                prompt=msg,
+                event_name="initial_agent_answer",
+                writer=writer,
+                agent_answer_level=0,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
+                max_tokens=(
+                    AGENT_MAX_TOKENS_ANSWER_GENERATION
+                    if _should_restrict_tokens(model.config)
+                    else None
+                ),
+            ),
         )

     except (LLMTimeoutError, TimeoutError):
@@ -40,6 +40,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
 from onyx.agents.agent_search.shared_graph_utils.constants import (
     AgentLLMErrorType,
 )
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
 from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
 from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
@@ -63,7 +64,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     remove_document_citations,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
-from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
@@ -301,45 +301,24 @@ def generate_validate_refined_answer(
-    dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None

-    def stream_refined_answer() -> list[str]:
-        for message in model.stream(
-            msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
-            max_tokens=(
-                AGENT_MAX_TOKENS_ANSWER_GENERATION
-                if _should_restrict_tokens(model.config)
-                else None
-            ),
-        ):
-            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-            content = message.content
-            if not isinstance(content, str):
-                raise ValueError(
-                    f"Expected content to be a string, but got {type(content)}"
-                )
-
-            start_stream_token = datetime.now()
-            write_custom_event(
-                "refined_agent_answer",
-                AgentAnswerPiece(
-                    answer_piece=content,
-                    level=1,
-                    level_question_num=0,
-                    answer_type="agent_level_answer",
-                ),
-                writer,
-            )
-            end_stream_token = datetime.now()
-            dispatch_timings.append(
-                (end_stream_token - start_stream_token).microseconds
-            )
-            streamed_tokens.append(content)
-        return streamed_tokens

     try:
-        streamed_tokens = run_with_timeout(
+        streamed_tokens, dispatch_timings = run_with_timeout(
             AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
-            stream_refined_answer,
+            lambda: stream_llm_answer(
+                llm=model,
+                prompt=msg,
+                event_name="refined_agent_answer",
+                writer=writer,
+                agent_answer_level=1,
+                agent_answer_question_num=0,
+                agent_answer_type="agent_level_answer",
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
+                max_tokens=(
+                    AGENT_MAX_TOKENS_ANSWER_GENERATION
+                    if _should_restrict_tokens(model.config)
+                    else None
+                ),
+            ),
         )

     except (LLMTimeoutError, TimeoutError):
backend/onyx/agents/agent_search/shared_graph_utils/llm.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from datetime import datetime
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langgraph.types import StreamWriter

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.llm.interfaces import LLM


def stream_llm_answer(
    llm: LLM,
    prompt: LanguageModelInput,
    event_name: str,
    writer: StreamWriter,
    agent_answer_level: int,
    agent_answer_question_num: int,
    agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
    timeout_override: int | None = None,
    max_tokens: int | None = None,
) -> tuple[list[str], list[float]]:
    """Stream the initial answer from the LLM.

    Args:
        llm: The LLM to use.
        prompt: The prompt to use.
        event_name: The name of the event to write.
        writer: The writer to write to.
        agent_answer_level: The level of the agent answer.
        agent_answer_question_num: The question number within the level.
        agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
        timeout_override: The LLM timeout to use.
        max_tokens: The LLM max tokens to use.

    Returns:
        A tuple of the response and the dispatch timings.
    """
    response: list[str] = []
    dispatch_timings: list[float] = []

    for message in llm.stream(
        prompt, timeout_override=timeout_override, max_tokens=max_tokens
    ):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )

        start_stream_token = datetime.now()
        write_custom_event(
            event_name,
            AgentAnswerPiece(
                answer_piece=content,
                level=agent_answer_level,
                level_question_num=agent_answer_question_num,
                answer_type=agent_answer_type,
            ),
            writer,
        )
        end_stream_token = datetime.now()

        dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
        response.append(content)

    return response, dispatch_timings
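As a reference for the new helper's interface, here is a minimal usage sketch outside of run_with_timeout; llm, msg, writer, level, and question_num are assumed to be the objects a node already has in scope, as in the call sites above:

    # Hypothetical direct call; the real call sites wrap this in run_with_timeout.
    streamed_tokens, dispatch_timings = stream_llm_answer(
        llm=llm,
        prompt=msg,
        event_name="sub_answers",
        writer=writer,
        agent_answer_level=level,
        agent_answer_question_num=question_num,
        agent_answer_type="agent_sub_answer",
        timeout_override=30,
        max_tokens=None,
    )
    answer = "".join(streamed_tokens)
    # dispatch_timings holds the per-token time spent dispatching write_custom_event.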