Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-08 03:48:14 +02:00)

commit a340529de3
parent 4a0b2a6c09

    sync streaming impl
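
The commit replaces LangChain's callback-based dispatch_custom_event with a synchronous write_custom_event helper (defined in onyx.agents.agent_search.shared_graph_utils.utils, final hunks below) and threads an explicit LangGraph StreamWriter through every graph node as a `writer` parameter. The core substitution, restated here as a sketch rather than as part of the diff:

    # before: relies on an active LangChain callback context
    dispatch_custom_event("basic_response", response_part)

    # after: writes through the LangGraph StreamWriter handed to the node
    write_custom_event("basic_response", response_part, writer)
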
@@ -1,10 +1,11 @@
 from collections.abc import Iterator
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage
+from langgraph.types import StreamWriter

+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import LlmDoc
 from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
 from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
@@ -20,6 +21,7 @@ logger = setup_logger()
 def process_llm_stream(
     messages: Iterator[BaseMessage],
     should_stream_answer: bool,
+    writer: StreamWriter,
     final_search_results: list[LlmDoc] | None = None,
     displayed_search_results: list[LlmDoc] | None = None,
 ) -> AIMessageChunk:
@@ -52,9 +54,10 @@ def process_llm_stream(
             tool_call_chunk += message  # type: ignore
         elif should_stream_answer:
             for response_part in answer_handler.handle_response_part(message, []):
-                dispatch_custom_event(
+                write_custom_event(
                     "basic_response",
                     response_part,
+                    writer,
                 )

     logger.info(f"Full answer: {full_answer}")
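Context for the new `writer` parameter appearing throughout these hunks: in LangGraph, a node whose signature has a StreamWriter-annotated parameter receives the active writer when the graph runs with stream_mode="custom", and the `lambda _: None` default used in this commit keeps the nodes callable outside a streaming run. A minimal, self-contained sketch of that mechanism under those assumptions (the toy graph and names are illustrative, not from this commit):

    from typing import TypedDict

    from langgraph.graph import END, START, StateGraph
    from langgraph.types import StreamWriter


    class State(TypedDict):
        answer: str


    def emit_tokens(state: State, writer: StreamWriter = lambda _: None) -> State:
        # LangGraph injects the real writer based on the StreamWriter annotation;
        # the no-op default applies when the node is called directly (e.g. in tests).
        for token in ("hello", " world"):
            writer({"event": "on_custom_event", "name": "demo", "data": token})
        return {"answer": "hello world"}


    graph = StateGraph(State)
    graph.add_node("emit_tokens", emit_tokens)
    graph.add_edge(START, "emit_tokens")
    graph.add_edge("emit_tokens", END)
    compiled = graph.compile()

    # stream_mode="custom" yields exactly the payloads handed to the writer
    for event in compiled.stream({"answer": ""}, stream_mode="custom"):
        print(event)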

@@ -2,9 +2,9 @@ from datetime import datetime
 from typing import Any
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.initial.generate_individual_sub_answer.states import (
     AnswerQuestionState,
@@ -22,6 +22,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_persona_agent_prompt_expressions,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
@@ -32,7 +33,9 @@ logger = setup_logger()


 def generate_sub_answer(
-    state: AnswerQuestionState, config: RunnableConfig
+    state: AnswerQuestionState,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
 ) -> QAGenerationUpdate:
     now_start = datetime.now()
     logger.info(f"--------{now_start}--------START ANSWER GENERATION---")
@@ -48,7 +51,7 @@ def generate_sub_answer(

     if len(context_docs) == 0:
         answer_str = NO_RECOVERED_DOCS
-        dispatch_custom_event(
+        write_custom_event(
             "sub_answers",
             AgentAnswerPiece(
                 answer_piece=answer_str,
@@ -56,6 +59,7 @@ def generate_sub_answer(
                 level_question_nr=question_nr,
                 answer_type="agent_sub_answer",
             ),
+            writer,
         )
     else:
         logger.debug(f"Number of verified retrieval docs: {len(context_docs)}")
@@ -81,7 +85,7 @@ def generate_sub_answer(
                     f"Expected content to be a string, but got {type(content)}"
                 )
             start_stream_token = datetime.now()
-            dispatch_custom_event(
+            write_custom_event(
                 "sub_answers",
                 AgentAnswerPiece(
                     answer_piece=content,
@@ -89,6 +93,7 @@ def generate_sub_answer(
                     level_question_nr=question_nr,
                     answer_type="agent_sub_answer",
                 ),
+                writer,
             )
             end_stream_token = datetime.now()
             dispatch_timings.append(
@@ -112,7 +117,7 @@ def generate_sub_answer(
         level=level,
         level_question_nr=question_nr,
     )
-    dispatch_custom_event("stream_finished", stop_event)
+    write_custom_event("stream_finished", stop_event, writer)

     now_end = datetime.now()

@@ -2,10 +2,10 @@ from datetime import datetime
 from typing import Any
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.initial.generate_initial_answer.states import (
     SearchSQState,
@@ -45,6 +45,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@@ -54,7 +55,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp


 def generate_initial_answer(
-    state: SearchSQState, config: RunnableConfig
+    state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> InitialAnswerUpdate:
     now_start = datetime.now()

@@ -98,7 +99,7 @@ def generate_initial_answer(
         get_section_relevance=lambda: None,  # TODO: add relevance
         search_tool=agent_a_config.search_tool,
     ):
-        dispatch_custom_event(
+        write_custom_event(
             "tool_response",
             ExtendedToolResponse(
                 id=tool_response.id,
@@ -106,10 +107,11 @@ def generate_initial_answer(
                 level=0,
                 level_question_nr=0,  # 0, 0 is the base question
             ),
+            writer,
         )

     if len(relevant_docs) == 0:
-        dispatch_custom_event(
+        write_custom_event(
             "initial_agent_answer",
             AgentAnswerPiece(
                 answer_piece=UNKNOWN_ANSWER,
@@ -117,8 +119,9 @@ def generate_initial_answer(
                 level_question_nr=0,
                 answer_type="agent_level_answer",
             ),
+            writer,
         )
-        dispatch_main_answer_stop_info(0)
+        dispatch_main_answer_stop_info(0, writer)

         answer = UNKNOWN_ANSWER
         initial_agent_stats = InitialAgentResultStats(
@@ -197,7 +200,7 @@ def generate_initial_answer(
                 )
             start_stream_token = datetime.now()

-            dispatch_custom_event(
+            write_custom_event(
                 "initial_agent_answer",
                 AgentAnswerPiece(
                     answer_piece=content,
@@ -205,6 +208,7 @@ def generate_initial_answer(
                     level_question_nr=0,
                     answer_type="agent_level_answer",
                 ),
+                writer,
             )
             end_stream_token = datetime.now()
             dispatch_timings.append(
@@ -216,7 +220,7 @@ def generate_initial_answer(
             f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
         )

-        dispatch_main_answer_stop_info(0)
+        dispatch_main_answer_stop_info(0, writer)
         response = merge_content(*streamed_tokens)
         answer = cast(str, response)

@@ -1,10 +1,10 @@
 from datetime import datetime
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.initial.generate_initial_answer.states import (
     SearchSQState,
@@ -28,6 +28,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import (
     INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import SubQuestionPiece
@@ -35,7 +36,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION


 def decompose_orig_question(
-    state: SearchSQState, config: RunnableConfig
+    state: SearchSQState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> BaseDecompUpdate:
     now_start = datetime.now()

@@ -90,23 +91,26 @@ def decompose_orig_question(
     msg = [HumanMessage(content=decomposition_prompt)]

     # Send the initial question as a subquestion with number 0
-    dispatch_custom_event(
+    write_custom_event(
         "decomp_qs",
         SubQuestionPiece(
             sub_question=question,
             level=0,
             level_question_nr=0,
         ),
+        writer,
     )
     # dispatches custom events for subquestion tokens, adding in subquestion ids.
-    streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(0))
+    streamed_tokens = dispatch_separated(
+        model.stream(msg), dispatch_subquestion(0, writer)
+    )

     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,
         stream_type="sub_questions",
         level=0,
     )
-    dispatch_custom_event("stream_finished", stop_event)
+    write_custom_event("stream_finished", stop_event, writer)

     deomposition_response = merge_content(*streamed_tokens)

@@ -1,19 +1,22 @@
 from datetime import datetime
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.main.operations import logger
 from onyx.agents.agent_search.deep_search_a.main.states import AnswerComparison
 from onyx.agents.agent_search.deep_search_a.main.states import MainState
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import RefinedAnswerImprovement


-def compare_answers(state: MainState, config: RunnableConfig) -> AnswerComparison:
+def compare_answers(
+    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
+) -> AnswerComparison:
     now_start = datetime.now()

     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
@@ -39,11 +42,12 @@ def compare_answers(state: MainState, config: RunnableConfig) -> AnswerCompariso
         isinstance(resp.content, str) and "yes" in resp.content.lower()
     )

-    dispatch_custom_event(
+    write_custom_event(
         "refined_answer_improvement",
         RefinedAnswerImprovement(
             refined_answer_improvement=refined_answer_improvement,
         ),
+        writer,
     )

     now_end = datetime.now()

@@ -1,10 +1,10 @@
 from datetime import datetime
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.main.models import (
     FollowUpSubQuestion,
@@ -29,15 +29,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     format_entity_term_extraction,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.tools.models import ToolCallKickoff


 def create_refined_sub_questions(
-    state: MainState, config: RunnableConfig
+    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> FollowUpSubQuestionsUpdate:
     """ """
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    dispatch_custom_event(
+    write_custom_event(
         "start_refined_answer_creation",
         ToolCallKickoff(
             tool_name="agent_search_1",
@@ -46,6 +47,7 @@ def create_refined_sub_questions(
                 "answer": state.initial_answer,
             },
         ),
+        writer,
     )

     now_start = datetime.now()
@@ -90,7 +92,9 @@ def create_refined_sub_questions(
     # Grader
     model = agent_a_config.fast_llm

-    streamed_tokens = dispatch_separated(model.stream(msg), dispatch_subquestion(1))
+    streamed_tokens = dispatch_separated(
+        model.stream(msg), dispatch_subquestion(1, writer)
+    )
     response = merge_content(*streamed_tokens)

     if isinstance(response, str):

@@ -2,10 +2,10 @@ from datetime import datetime
 from typing import Any
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.main.models import (
     AgentRefinedMetrics,
@@ -44,6 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@@ -52,7 +53,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp


 def generate_refined_answer(
-    state: MainState, config: RunnableConfig
+    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> RefinedAnswerUpdate:
     now_start = datetime.now()

@@ -102,7 +103,7 @@ def generate_refined_answer(
         get_section_relevance=lambda: None,  # TODO: add relevance
         search_tool=agent_a_config.search_tool,
     ):
-        dispatch_custom_event(
+        write_custom_event(
             "tool_response",
             ExtendedToolResponse(
                 id=tool_response.id,
@@ -110,6 +111,7 @@ def generate_refined_answer(
                 level=1,
                 level_question_nr=0,  # 0, 0 is the base question
             ),
+            writer,
         )

     if len(initial_documents) > 0:
@@ -228,7 +230,7 @@ def generate_refined_answer(
                 )

             start_stream_token = datetime.now()
-            dispatch_custom_event(
+            write_custom_event(
                 "refined_agent_answer",
                 AgentAnswerPiece(
                     answer_piece=content,
@@ -236,6 +238,7 @@ def generate_refined_answer(
                     level_question_nr=0,
                     answer_type="agent_level_answer",
                 ),
+                writer,
             )
             end_stream_token = datetime.now()
             dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
@@ -244,7 +247,7 @@ def generate_refined_answer(
     logger.info(
         f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
     )
-    dispatch_main_answer_stop_info(1)
+    dispatch_main_answer_stop_info(1, writer)
     response = merge_content(*streamed_tokens)
     answer = cast(str, response)

@@ -1,7 +1,7 @@
 import re
 from collections.abc import Callable

-from langchain_core.callbacks.manager import dispatch_custom_event
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
@@ -9,6 +9,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
 from onyx.agents.agent_search.shared_graph_utils.models import (
     QuestionAnswerResults,
 )
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import SubQuestionPiece
 from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
@@ -38,15 +39,18 @@ def remove_document_citations(text: str) -> str:
     return re.sub(r"\[\[(?:D|Q)\d+\]\]\(\)", "", text)


-def dispatch_subquestion(level: int) -> Callable[[str, int], None]:
+def dispatch_subquestion(
+    level: int, writer: StreamWriter
+) -> Callable[[str, int], None]:
     def _helper(sub_question_part: str, num: int) -> None:
-        dispatch_custom_event(
+        write_custom_event(
             "decomp_qs",
             SubQuestionPiece(
                 sub_question=sub_question_part,
                 level=level,
                 level_question_nr=num,
             ),
+            writer,
         )

     return _helper
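Note on the factory above: dispatch_subquestion (like dispatch_subquery further down) now closes over the writer so dispatch_separated can keep its two-argument (token, index) callback shape. The same pattern in isolation, as a hedged sketch (make_dispatcher and the payload shape are hypothetical stand-ins; the real code builds a SubQuestionPiece):

    from collections.abc import Callable

    StreamWriter = Callable[[dict], None]  # stand-in for langgraph.types.StreamWriter


    def make_dispatcher(level: int, writer: StreamWriter) -> Callable[[str, int], None]:
        # The writer is captured in the closure, so callers that only know the
        # (token, index) shape never have to thread a writer through themselves.
        def _helper(token: str, num: int) -> None:
            writer({"name": "decomp_qs", "level": level, "num": num, "data": token})

        return _helper


    events: list[dict] = []
    dispatch = make_dispatcher(0, events.append)
    dispatch("What is", 0)
    dispatch(" streaming?", 0)
    print(events)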

@@ -4,6 +4,7 @@ from typing import cast
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.operations import (
     dispatch_subquery,
@@ -26,7 +27,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id


 def expand_queries(
-    state: ExpandedRetrievalInput, config: RunnableConfig
+    state: ExpandedRetrievalInput,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
 ) -> QueryExpansionUpdate:
     # Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
     # When we are running this node on the original question, no question is explictly passed in.
@@ -53,7 +56,7 @@ def expand_queries(
     ]

     llm_response_list = dispatch_separated(
-        llm.stream(prompt=msg), dispatch_subquery(level, question_nr)
+        llm.stream(prompt=msg), dispatch_subquery(level, question_nr, writer)
     )

     llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

@@ -1,7 +1,7 @@
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.models import (
     ExpandedRetrievalResult,
@@ -18,12 +18,15 @@ from onyx.agents.agent_search.deep_search_a.shared.expanded_retrieval.states imp
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import ExtendedToolResponse
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses


 def format_results(
-    state: ExpandedRetrievalState, config: RunnableConfig
+    state: ExpandedRetrievalState,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
 ) -> ExpandedRetrievalUpdate:
     level, question_nr = parse_question_id(state.sub_question_id or "0_0")
     query_infos = [
@@ -55,7 +58,7 @@ def format_results(
         get_section_relevance=lambda: None,  # TODO: add relevance
         search_tool=agent_a_config.search_tool,
     ):
-        dispatch_custom_event(
+        write_custom_event(
             "tool_response",
             ExtendedToolResponse(
                 id=tool_response.id,
@@ -63,6 +66,7 @@ def format_results(
                 level=level,
                 level_question_nr=question_nr,
             ),
+            writer,
         )
     sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
         verified_documents=state.verified_documents,

@@ -2,10 +2,11 @@ from collections import defaultdict
 from collections.abc import Callable

 import numpy as np
-from langchain_core.callbacks.manager import dispatch_custom_event
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import SubQueryPiece
 from onyx.context.search.models import InferenceSection
 from onyx.utils.logger import setup_logger
@@ -13,9 +14,11 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None]:
+def dispatch_subquery(
+    level: int, question_nr: int, writer: StreamWriter
+) -> Callable[[str, int], None]:
     def helper(token: str, num: int) -> None:
-        dispatch_custom_event(
+        write_custom_event(
             "subqueries",
             SubQueryPiece(
                 sub_query=token,
@@ -23,6 +26,7 @@ def dispatch_subquery(level: int, question_nr: int) -> Callable[[str, int], None
                 level_question_nr=question_nr,
                 query_id=num,
             ),
+            writer,
         )

     return helper

@@ -2,6 +2,7 @@ from typing import cast

 from langchain_core.messages import AIMessageChunk
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
@@ -19,7 +20,9 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicOutput:
+def basic_use_tool_response(
+    state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
+) -> BasicOutput:
     agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
     structured_response_format = agent_config.structured_response_format
     llm = agent_config.primary_llm
@@ -65,6 +68,7 @@ def basic_use_tool_response(state: BasicState, config: RunnableConfig) -> BasicO
     new_tool_call_chunk = process_llm_stream(
         stream,
         True,
+        writer,
         final_search_results=final_search_results,
         displayed_search_results=initial_search_results,
     )

@@ -3,6 +3,7 @@ from uuid import uuid4

 from langchain_core.messages import ToolCall
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -24,7 +25,11 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoiceUpdate:
+def llm_tool_choice(
+    state: ToolChoiceState,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
+) -> ToolChoiceUpdate:
     """
     This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
     The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
@@ -97,7 +102,9 @@ def llm_tool_choice(state: ToolChoiceState, config: RunnableConfig) -> ToolChoic
     )

     tool_message = process_llm_stream(
-        stream, should_stream_answer and not agent_config.skip_gen_ai_answer_generation
+        stream,
+        should_stream_answer and not agent_config.skip_gen_ai_answer_generation,
+        writer,
     )

     # If no tool calls are emitted by the LLM, we should not choose a tool

@@ -1,14 +1,15 @@
 from typing import cast

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig
+from langgraph.types import StreamWriter

 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.orchestration.states import ToolCallOutput
 from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
+from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AnswerPacket
 from onyx.tools.message import build_tool_message
 from onyx.tools.message import ToolCallSummary
@@ -19,11 +20,15 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-def emit_packet(packet: AnswerPacket) -> None:
-    dispatch_custom_event("basic_response", packet)
+def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
+    write_custom_event("basic_response", packet, writer)


-def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate:
+def tool_call(
+    state: ToolChoiceUpdate,
+    config: RunnableConfig,
+    writer: StreamWriter = lambda _: None,
+) -> ToolCallUpdate:
     """Calls the tool specified in the state and updates the state with the result"""

     cast(AgentSearchConfig, config["metadata"]["config"])
@@ -38,15 +43,15 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate
     tool_runner = ToolRunner(tool, tool_args)
     tool_kickoff = tool_runner.kickoff()

-    emit_packet(tool_kickoff)
+    emit_packet(tool_kickoff, writer)

     tool_responses = []
     for response in tool_runner.tool_responses():
         tool_responses.append(response)
-        emit_packet(response)
+        emit_packet(response, writer)

     tool_final_result = tool_runner.tool_final_result()
-    emit_packet(tool_final_result)
+    emit_packet(tool_final_result, writer)

     tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
     tool_call_summary = ToolCallSummary(

@@ -4,6 +4,7 @@ from collections.abc import Iterable
 from datetime import datetime
 from typing import cast

+from langchain_core.runnables.schema import CustomStreamEvent
 from langchain_core.runnables.schema import StreamEvent
 from langgraph.graph.state import CompiledStateGraph

@@ -144,11 +145,8 @@ def manage_sync_streaming(
         stream_mode="custom",
         input=graph_input,
         config={"metadata": {"config": config, "thread_id": str(message_id)}},
-        # debug=True,
     ):
-        print(event)
-
-    return []
+        yield cast(CustomStreamEvent, event)


 def run_graph(
@@ -159,7 +157,7 @@ def run_graph(
     config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
     config.allow_refinement = ALLOW_REFINEMENT

-    for event in _manage_async_event_streaming(
+    for event in manage_sync_streaming(
         compiled_graph=compiled_graph, config=config, graph_input=input
     ):
         if not (parsed_object := _parse_agent_event(event)):

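With the print/return replaced by a yield, manage_sync_streaming is now a generator, so run_graph can consume events lazily as nodes write them instead of waiting for the graph to finish. Each yielded event is the plain dict a node handed to its StreamWriter, i.e. the CustomStreamEvent built by write_custom_event (defined in the next file). A stand-alone illustration of that shape (the data field carries an AnswerPacket in the real code; a string stands in here):

    from typing import Any, Literal, TypedDict


    class CustomStreamEvent(TypedDict):
        event: Literal["on_custom_event"]
        name: str
        data: Any


    example: CustomStreamEvent = {
        "event": "on_custom_event",
        "name": "basic_response",
        "data": "partial answer token",
    }
    print(example["name"], example["data"])
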
@@ -8,11 +8,13 @@ from datetime import datetime
 from datetime import timedelta
 from typing import Any
 from typing import cast
+from typing import Literal
+from typing import TypedDict
 from uuid import UUID

-from langchain_core.callbacks.manager import dispatch_custom_event
 from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
+from langgraph.types import StreamWriter
 from sqlalchemy.orm import Session

 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -30,6 +32,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
     HISTORY_CONTEXT_SUMMARY_PROMPT,
 )
+from onyx.chat.models import AnswerPacket
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import DocumentPruningConfig
@@ -57,6 +60,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
 from onyx.tools.tool_implementations.search.search_tool import SearchTool


 BaseMessage_Content = str | list[str | dict[str, Any]]


@@ -314,13 +318,13 @@ def dispatch_separated(
     return streamed_tokens


-def dispatch_main_answer_stop_info(level: int) -> None:
+def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,
         stream_type="main_answer",
         level=level,
     )
-    dispatch_custom_event("stream_finished", stop_event)
+    write_custom_event("stream_finished", stop_event, writer)


 def get_today_prompt() -> str:
@@ -368,3 +372,22 @@ def summarize_history(
         history_context_response_str = ""

     return history_context_response_str
+
+
+# taken from langchain_core.runnables.schema
+# we don't use the one from their library because
+# it includes ids they generate
+class CustomStreamEvent(TypedDict):
+    # Overwrite the event field to be more specific.
+    event: Literal["on_custom_event"]  # type: ignore[misc]
+    """The event type."""
+    name: str
+    """User defined name for the event."""
+    data: Any
+    """The data associated with the event. Free form and can be anything."""
+
+
+def write_custom_event(
+    name: str, event: AnswerPacket, stream_writer: StreamWriter
+) -> None:
+    stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
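
A quick sketch of the new helper in use (assumption: any one-argument callable satisfies StreamWriter, so a list.append works as a test double for the writer LangGraph would inject; the packets below are string stand-ins for AnswerPacket values):

    captured: list[dict] = []

    write_custom_event("sub_answers", "first token", captured.append)
    write_custom_event("stream_finished", "stop info", captured.append)

    # CustomStreamEvent is a TypedDict, so each captured event is a plain dict
    assert captured[0] == {
        "event": "on_custom_event",
        "name": "sub_answers",
        "data": "first token",
    }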