mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-05 17:30:26 +02:00
218 lines
8.5 KiB
Python
218 lines
8.5 KiB
Python
from collections.abc import Iterable
|
|
from datetime import datetime
|
|
from typing import cast
|
|
|
|
from langchain_core.runnables.schema import CustomStreamEvent
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
|
|
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
|
|
from onyx.agents.agent_search.basic.states import BasicInput
|
|
from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
|
main_graph_builder as main_graph_builder_a,
|
|
)
|
|
from onyx.agents.agent_search.deep_search.main.states import (
|
|
MainInput as MainInput_a,
|
|
)
|
|
from onyx.agents.agent_search.models import GraphConfig
|
|
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
|
from onyx.chat.models import AgentAnswerPiece
|
|
from onyx.chat.models import AnswerPacket
|
|
from onyx.chat.models import AnswerStream
|
|
from onyx.chat.models import ExtendedToolResponse
|
|
from onyx.chat.models import RefinedAnswerImprovement
|
|
from onyx.chat.models import StreamStopInfo
|
|
from onyx.chat.models import SubQueryPiece
|
|
from onyx.chat.models import SubQuestionPiece
|
|
from onyx.chat.models import ToolResponse
|
|
from onyx.configs.agent_configs import ALLOW_REFINEMENT
|
|
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
|
from onyx.context.search.models import SearchRequest
|
|
from onyx.db.engine import get_session_context_manager
|
|
from onyx.llm.factory import get_default_llms
|
|
from onyx.tools.tool_runner import ToolCallKickoff
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Module-level logger used by the runner functions and the __main__ harness.
logger = setup_logger()

# Cache for the compiled main (deep-search) graph. Starts as None and is filled
# in on first use by load_compiled_graph(), which then returns it on later calls.
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
|
|
|
|
|
def _parse_agent_event(
|
|
event: StreamEvent,
|
|
) -> AnswerPacket | None:
|
|
"""
|
|
Parse the event into a typed object.
|
|
Return None if we are not interested in the event.
|
|
"""
|
|
|
|
event_type = event["event"]
|
|
|
|
# We always just yield the event data, but this piece is useful for two development reasons:
|
|
# 1. It's a list of the names of every place we dispatch a custom event
|
|
# 2. We maintain the intended types yielded by each event
|
|
if event_type == "on_custom_event":
|
|
if event["name"] == "decomp_qs":
|
|
return cast(SubQuestionPiece, event["data"])
|
|
elif event["name"] == "subqueries":
|
|
return cast(SubQueryPiece, event["data"])
|
|
elif event["name"] == "sub_answers":
|
|
return cast(AgentAnswerPiece, event["data"])
|
|
elif event["name"] == "stream_finished":
|
|
return cast(StreamStopInfo, event["data"])
|
|
elif event["name"] == "initial_agent_answer":
|
|
return cast(AgentAnswerPiece, event["data"])
|
|
elif event["name"] == "refined_agent_answer":
|
|
return cast(AgentAnswerPiece, event["data"])
|
|
elif event["name"] == "start_refined_answer_creation":
|
|
return cast(ToolCallKickoff, event["data"])
|
|
elif event["name"] == "tool_response":
|
|
return cast(ToolResponse, event["data"])
|
|
elif event["name"] == "basic_response":
|
|
return cast(AnswerPacket, event["data"])
|
|
elif event["name"] == "refined_answer_improvement":
|
|
return cast(RefinedAnswerImprovement, event["data"])
|
|
return None
|
|
|
|
|
|
def manage_sync_streaming(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
    """
    Run ``compiled_graph`` synchronously and yield its custom-stream events.

    The graph is streamed with ``stream_mode="custom"``. The ``GraphConfig``
    is handed to the run under ``metadata["config"]``; presumably the graph
    nodes read it from there (verify against the node implementations).

    The run's ``thread_id`` is the string form of
    ``config.persistence.message_id``. When ``config.persistence`` is falsy,
    it is the literal string ``"None"``.

    Each streamed item is re-typed with ``cast`` only; it is not converted
    or validated.
    """
    persistence = config.persistence
    thread_id = str(persistence.message_id if persistence else None)
    run_config = {"metadata": {"config": config, "thread_id": thread_id}}

    graph_stream = compiled_graph.stream(
        stream_mode="custom",
        input=graph_input,
        config=run_config,
    )
    for raw_event in graph_stream:
        yield cast(CustomStreamEvent, raw_event)
|
|
|
|
|
|
def run_graph(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    input: BasicInput | MainInput_a,
) -> AnswerStream:
    """
    Run ``compiled_graph`` on ``input`` and yield the typed packets it emits.

    Side effect: before streaming, two flags on the caller's ``config`` are
    overwritten with process-wide settings. Any value the caller set for
    them beforehand is discarded.

    * ``config.behavior.perform_initial_search_decomposition`` is set to
      ``INITIAL_SEARCH_DECOMPOSITION_ENABLED``.
    * ``config.behavior.allow_refinement`` is set to ``ALLOW_REFINEMENT``.

    Because this is a generator, the overwrite happens when iteration
    begins, not when ``run_graph`` is called.

    Args:
        compiled_graph: The graph to execute.
        config: Run configuration. It is mutated as described above and
            forwarded to ``manage_sync_streaming``.
        input: Initial graph state. The name shadows the builtin but is
            kept for caller compatibility.

    Yields:
        Every packet ``_parse_agent_event`` returns for a streamed event.
        Events it maps to ``None`` are skipped.
    """
    config.behavior.perform_initial_search_decomposition = (
        INITIAL_SEARCH_DECOMPOSITION_ENABLED
    )
    config.behavior.allow_refinement = ALLOW_REFINEMENT

    for event in manage_sync_streaming(
        compiled_graph=compiled_graph, config=config, graph_input=input
    ):
        parsed_object = _parse_agent_event(event)
        # None is the parser's only "not interested" signal. Test for it
        # explicitly so a legitimately falsy packet is not silently dropped.
        if parsed_object is None:
            continue

        yield parsed_object
|
|
|
|
|
|
def load_compiled_graph() -> CompiledStateGraph:
    """
    Return the compiled main (deep-search) graph, building it on first use.

    Compiling the graph is fairly quick, but repeating it on every request
    is pointless. The result is therefore memoized in the module-level
    ``_COMPILED_GRAPH``.

    NOTE(review): the check-then-assign is not lock-protected. Concurrent
    first calls may each compile a graph before the cache is populated.
    """
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is not None:
        return _COMPILED_GRAPH

    _COMPILED_GRAPH = main_graph_builder_a().compile()
    return _COMPILED_GRAPH
|
|
|
|
|
|
def run_main_graph(
    config: GraphConfig,
) -> AnswerStream:
    """
    Stream the deep-search answer for ``config``'s search query.

    The function first emits a synthetic ``ToolCallKickoff`` named
    ``"agent_search_0"`` whose only argument is the query. Agent search is
    not really a tool, but presenting it as one is helpful for the frontend.
    Every packet produced by ``run_graph`` on the cached main graph follows.

    This is a generator, so nothing happens until iteration starts. That
    includes compiling the graph on the very first run.
    """
    compiled_graph = load_compiled_graph()
    question = config.inputs.search_request.query
    main_input = MainInput_a(base_question=question, log_messages=[])

    yield ToolCallKickoff(
        tool_name="agent_search_0",
        tool_args={"query": question},
    )
    yield from run_graph(compiled_graph, config, main_input)
|
|
|
|
|
|
# Cache for the compiled basic graph. It mirrors _COMPILED_GRAPH for the main
# graph and is filled in on first use by _load_compiled_basic_graph().
_COMPILED_BASIC_GRAPH: CompiledStateGraph | None = None


def _load_compiled_basic_graph() -> CompiledStateGraph:
    """Return the compiled basic graph, compiling it only on the first call."""
    global _COMPILED_BASIC_GRAPH
    if _COMPILED_BASIC_GRAPH is None:
        # basic_graph_builder() takes no arguments, so the compiled graph
        # carries no per-request data and can be reused across requests.
        _COMPILED_BASIC_GRAPH = basic_graph_builder().compile()
    return _COMPILED_BASIC_GRAPH


def run_basic_graph(
    config: GraphConfig,
) -> AnswerStream:
    """
    Stream the answer for ``config`` using the basic (non-agentic) graph.

    The graph is compiled once and reused on later calls, just as
    ``load_compiled_graph`` does for the main graph. Previously it was
    recompiled on every request.

    Unlike ``run_main_graph``, no leading ``ToolCallKickoff`` packet is
    emitted. This function is not itself a generator: it returns the lazy
    ``run_graph`` stream, with an empty ``BasicInput`` as initial state.
    """
    compiled_graph = _load_compiled_basic_graph()
    basic_input = BasicInput()
    return run_graph(compiled_graph, config, basic_input)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual harness. It compiles the main (deep-search) graph, runs it
    # against a hard-coded query and logs the streamed packets. It needs a
    # working DB session and default LLM configuration.
    num_runs = 1  # raise to repeat the whole run, e.g. for rough timing
    for _ in range(num_runs):
        query_start_time = datetime.now()
        logger.debug(f"Start at {query_start_time}")
        graph = main_graph_builder_a()
        compiled_graph = graph.compile()
        query_end_time = datetime.now()
        # str(timedelta) renders as "H:MM:SS.ffffff". Convert explicitly so
        # the logged number matches the "seconds" unit in the message.
        compile_seconds = (query_end_time - query_start_time).total_seconds()
        logger.debug(f"Graph compiled in {compile_seconds} seconds")
        primary_llm, fast_llm = get_default_llms()
        search_request = SearchRequest(
            # Alternative queries for manual testing:
            # query="what can you do with gitlab?",
            # query="What are the guiding principles behind the development of cockroachDB",
            # query="What are the temperatures in Munich, Hawaii, and New York?",
            # query="When was Washington born?",
            # query="What is Onyx?",
            # query="What is the difference between astronomy and astrology?",
            query="Do a search to tell me what is the difference between astronomy and astrology?",
        )
        # Joachim custom persona

        with get_session_context_manager() as db_session:
            config = get_test_config(db_session, primary_llm, fast_llm, search_request)
            # Raise instead of assert so the check survives `python -O`.
            if config.persistence is None:
                raise RuntimeError("set a chat session id to run this test")

            # search_request.persona = get_persona_by_id(1, None, db_session)
            # config.perform_initial_search_path_decision = False
            # NOTE: run_graph() overwrites this flag with
            # INITIAL_SEARCH_DECOMPOSITION_ENABLED before streaming, so this
            # assignment currently has no effect on the run below.
            config.behavior.perform_initial_search_decomposition = True
            main_input = MainInput_a(
                base_question=config.inputs.search_request.query, log_messages=[]
            )

            # Raw tool responses are collected for inspection in a debugger.
            tool_responses: list = []
            for output in run_graph(compiled_graph, config, main_input):
                if isinstance(output, ToolCallKickoff):
                    pass
                elif isinstance(output, ExtendedToolResponse):
                    tool_responses.append(output.response)
                    logger.info(
                        f" ---- ET {output.level} - {output.level_question_num} | "
                    )
                elif isinstance(output, SubQueryPiece):
                    logger.info(
                        f"Sq {output.level} - {output.level_question_num} - {output.sub_query} | "
                    )
                elif isinstance(output, SubQuestionPiece):
                    logger.info(
                        f"SQ {output.level} - {output.level_question_num} - {output.sub_question} | "
                    )
                elif (
                    isinstance(output, AgentAnswerPiece)
                    and output.answer_type == "agent_sub_answer"
                ):
                    logger.info(
                        f" ---- SA {output.level} - {output.level_question_num} {output.answer_piece} | "
                    )
                elif (
                    isinstance(output, AgentAnswerPiece)
                    and output.answer_type == "agent_level_answer"
                ):
                    logger.info(
                        f" ---------- FA {output.level} - {output.level_question_num} {output.answer_piece} | "
                    )
                elif isinstance(output, RefinedAnswerImprovement):
                    logger.info(
                        f" ---------- RE {output.refined_answer_improvement} | "
                    )
|