mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-25 03:00:14 +02:00
281 lines
11 KiB
Python
281 lines
11 KiB
Python
import asyncio
|
|
from collections.abc import AsyncIterable
|
|
from collections.abc import Iterable
|
|
from datetime import datetime
|
|
from typing import cast
|
|
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
|
|
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
|
|
from onyx.agents.agent_search.basic.states import BasicInput
|
|
from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import (
|
|
main_graph_builder as main_graph_builder_a,
|
|
)
|
|
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
|
|
MainInput as MainInput_a,
|
|
)
|
|
from onyx.agents.agent_search.models import AgentSearchConfig
|
|
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
|
from onyx.chat.models import AgentAnswerPiece
|
|
from onyx.chat.models import AnswerPacket
|
|
from onyx.chat.models import AnswerStream
|
|
from onyx.chat.models import ExtendedToolResponse
|
|
from onyx.chat.models import RefinedAnswerImprovement
|
|
from onyx.chat.models import StreamStopInfo
|
|
from onyx.chat.models import SubQueryPiece
|
|
from onyx.chat.models import SubQuestionPiece
|
|
from onyx.chat.models import ToolResponse
|
|
from onyx.configs.agent_configs import GRAPH_VERSION_NAME
|
|
from onyx.context.search.models import SearchRequest
|
|
from onyx.db.engine import get_session_context_manager
|
|
from onyx.tools.tool_runner import ToolCallKickoff
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
logger = setup_logger()
|
|
|
|
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
|
|
|
|
|
def _set_combined_token_value(
|
|
combined_token: str, parsed_object: AgentAnswerPiece
|
|
) -> AgentAnswerPiece:
|
|
parsed_object.answer_piece = combined_token
|
|
|
|
return parsed_object
|
|
|
|
|
|
def _parse_agent_event(
|
|
event: StreamEvent,
|
|
) -> AnswerPacket | None:
|
|
"""
|
|
Parse the event into a typed object.
|
|
Return None if we are not interested in the event.
|
|
"""
|
|
|
|
event_type = event["event"]
|
|
|
|
# We always just yield the event data, but this piece is useful for two development reasons:
|
|
# 1. It's a list of the names of every place we dispatch a custom event
|
|
# 2. We maintain the intended types yielded by each event
|
|
if event_type == "on_custom_event":
|
|
# TODO: different AnswerStream types for different events
|
|
if event["name"] == "decomp_qs":
|
|
return cast(SubQuestionPiece, event["data"])
|
|
elif event["name"] == "subqueries":
|
|
return cast(SubQueryPiece, event["data"])
|
|
elif event["name"] == "sub_answers":
|
|
return cast(AgentAnswerPiece, event["data"])
|
|
elif event["name"] == "stream_finished":
|
|
return cast(StreamStopInfo, event["data"])
|
|
elif event["name"] == "initial_agent_answer":
|
|
return cast(AgentAnswerPiece, event["data"])
|
|
elif event["name"] == "refined_agent_answer":
|
|
return cast(AgentAnswerPiece, event["data"])
|
|
elif event["name"] == "start_refined_answer_creation":
|
|
return cast(ToolCallKickoff, event["data"])
|
|
elif event["name"] == "tool_response":
|
|
return cast(ToolResponse, event["data"])
|
|
elif event["name"] == "basic_response":
|
|
return cast(AnswerPacket, event["data"])
|
|
elif event["name"] == "refined_answer_improvement":
|
|
return cast(RefinedAnswerImprovement, event["data"])
|
|
return None
|
|
|
|
|
|
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
|
|
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
|
|
task_references: set[asyncio.Task[StreamEvent]] = set()
|
|
|
|
|
|
def _manage_async_event_streaming(
|
|
compiled_graph: CompiledStateGraph,
|
|
config: AgentSearchConfig | None,
|
|
graph_input: MainInput_a | BasicInput,
|
|
) -> Iterable[StreamEvent]:
|
|
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
|
message_id = config.message_id if config else None
|
|
async for event in compiled_graph.astream_events(
|
|
input=graph_input,
|
|
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
|
# debug=True,
|
|
# indicating v2 here deserves further scrutiny
|
|
version="v2",
|
|
):
|
|
yield event
|
|
|
|
# This might be able to be simplified
|
|
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
|
loop = asyncio.new_event_loop()
|
|
try:
|
|
# Get the async generator
|
|
async_gen = _run_async_event_stream()
|
|
# Convert to AsyncIterator
|
|
async_iter = async_gen.__aiter__()
|
|
while True:
|
|
try:
|
|
# Create a coroutine by calling anext with the async iterator
|
|
next_coro = anext(async_iter)
|
|
task = asyncio.ensure_future(next_coro, loop=loop)
|
|
task_references.add(task)
|
|
# Run the coroutine to get the next event
|
|
event = loop.run_until_complete(task)
|
|
yield event
|
|
except (StopAsyncIteration, GeneratorExit):
|
|
break
|
|
finally:
|
|
try:
|
|
for task in task_references.pop():
|
|
task.cancel()
|
|
except StopAsyncIteration:
|
|
pass
|
|
loop.close()
|
|
|
|
return _yield_async_to_sync()
|
|
|
|
|
|
def run_graph(
|
|
compiled_graph: CompiledStateGraph,
|
|
config: AgentSearchConfig,
|
|
input: BasicInput | MainInput_a,
|
|
) -> AnswerStream:
|
|
# TODO: add these to the environment
|
|
# config.perform_initial_search_path_decision = False
|
|
config.perform_initial_search_decomposition = True
|
|
config.allow_refinement = True
|
|
|
|
for event in _manage_async_event_streaming(
|
|
compiled_graph=compiled_graph, config=config, graph_input=input
|
|
):
|
|
if not (parsed_object := _parse_agent_event(event)):
|
|
continue
|
|
|
|
yield parsed_object
|
|
|
|
|
|
# TODO: call this once on startup, TBD where and if it should be gated based
|
|
# on dev mode or not
|
|
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
|
|
main_graph_builder = (
|
|
main_graph_builder_a if graph_name == "a" else main_graph_builder_a
|
|
)
|
|
global _COMPILED_GRAPH
|
|
if _COMPILED_GRAPH is None:
|
|
graph = main_graph_builder()
|
|
_COMPILED_GRAPH = graph.compile()
|
|
return _COMPILED_GRAPH
|
|
|
|
|
|
def run_main_graph(
|
|
config: AgentSearchConfig,
|
|
graph_name: str = "a",
|
|
) -> AnswerStream:
|
|
compiled_graph = load_compiled_graph(graph_name)
|
|
if graph_name == "a":
|
|
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
|
else:
|
|
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
|
|
|
# Agent search is not a Tool per se, but this is helpful for the frontend
|
|
yield ToolCallKickoff(
|
|
tool_name="agent_search_0",
|
|
tool_args={"query": config.search_request.query},
|
|
)
|
|
yield from run_graph(compiled_graph, config, input)
|
|
|
|
|
|
# TODO: unify input types, especially prosearchconfig
|
|
def run_basic_graph(
|
|
config: AgentSearchConfig,
|
|
) -> AnswerStream:
|
|
graph = basic_graph_builder()
|
|
compiled_graph = graph.compile()
|
|
# TODO: unify basic input
|
|
input = BasicInput()
|
|
return run_graph(compiled_graph, config, input)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from onyx.llm.factory import get_default_llms
|
|
|
|
for _ in range(1):
|
|
now_start = datetime.now()
|
|
logger.debug(f"Start at {now_start}")
|
|
|
|
if GRAPH_VERSION_NAME == "a":
|
|
graph = main_graph_builder_a()
|
|
else:
|
|
graph = main_graph_builder_a()
|
|
compiled_graph = graph.compile()
|
|
now_end = datetime.now()
|
|
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
|
|
primary_llm, fast_llm = get_default_llms()
|
|
search_request = SearchRequest(
|
|
# query="what can you do with gitlab?",
|
|
# query="What are the guiding principles behind the development of cockroachDB",
|
|
# query="What are the temperatures in Munich, Hawaii, and New York?",
|
|
# query="When was Washington born?",
|
|
# query="What is Onyx?",
|
|
# query="What is the difference between astronomy and astrology?",
|
|
query="Do a search to tell me hat is the difference between astronomy and astrology?",
|
|
)
|
|
# Joachim custom persona
|
|
|
|
with get_session_context_manager() as db_session:
|
|
config, search_tool = get_test_config(
|
|
db_session, primary_llm, fast_llm, search_request
|
|
)
|
|
# search_request.persona = get_persona_by_id(1, None, db_session)
|
|
config.use_persistence = True
|
|
# config.perform_initial_search_path_decision = False
|
|
config.perform_initial_search_decomposition = True
|
|
if GRAPH_VERSION_NAME == "a":
|
|
input = MainInput_a(
|
|
base_question=config.search_request.query, log_messages=[]
|
|
)
|
|
else:
|
|
input = MainInput_a(
|
|
base_question=config.search_request.query, log_messages=[]
|
|
)
|
|
# with open("output.txt", "w") as f:
|
|
tool_responses: list = []
|
|
for output in run_graph(compiled_graph, config, input):
|
|
# pass
|
|
|
|
if isinstance(output, ToolCallKickoff):
|
|
pass
|
|
elif isinstance(output, ExtendedToolResponse):
|
|
tool_responses.append(output.response)
|
|
logger.info(
|
|
f" ---- ET {output.level} - {output.level_question_nr} | "
|
|
)
|
|
elif isinstance(output, SubQueryPiece):
|
|
logger.info(
|
|
f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
|
|
)
|
|
elif isinstance(output, SubQuestionPiece):
|
|
logger.info(
|
|
f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
|
|
)
|
|
elif (
|
|
isinstance(output, AgentAnswerPiece)
|
|
and output.answer_type == "agent_sub_answer"
|
|
):
|
|
logger.info(
|
|
f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
|
|
)
|
|
elif (
|
|
isinstance(output, AgentAnswerPiece)
|
|
and output.answer_type == "agent_level_answer"
|
|
):
|
|
logger.info(
|
|
f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
|
|
)
|
|
elif isinstance(output, RefinedAnswerImprovement):
|
|
logger.info(
|
|
f" ---------- RE {output.refined_answer_improvement} | "
|
|
)
|
|
|
|
# for tool_response in tool_responses:
|
|
# logger.debug(tool_response)
|