joachim-danswer d70bbcc2ce k
2025-02-03 20:10:50 -08:00

281 lines
11 KiB
Python

import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from typing import cast
from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import (
MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.configs.agent_configs import GRAPH_VERSION_NAME
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Process-wide cache for the compiled main agent graph. It starts empty, is
# filled on first use by load_compiled_graph(), and is returned as-is afterwards.
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _set_combined_token_value(
    combined_token: str, parsed_object: AgentAnswerPiece
) -> AgentAnswerPiece:
    """
    Replace the answer text of ``parsed_object`` with ``combined_token``.

    The piece is mutated in place; the same object is handed back so the
    call can be used inline.
    """
    target = parsed_object
    target.answer_piece = combined_token
    return target
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
# We always just yield the event data, but this piece is useful for two development reasons:
# 1. It's a list of the names of every place we dispatch a custom event
# 2. We maintain the intended types yielded by each event
if event_type == "on_custom_event":
# TODO: different AnswerStream types for different events
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "stream_finished":
return cast(StreamStopInfo, event["data"])
elif event["name"] == "initial_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "refined_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "start_refined_answer_creation":
return cast(ToolCallKickoff, event["data"])
elif event["name"] == "tool_response":
return cast(ToolResponse, event["data"])
elif event["name"] == "basic_response":
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
return None
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
# Process-wide set of strong references to the per-event asyncio tasks that
# _manage_async_event_streaming creates while bridging the graph's async event
# stream to synchronous code. NOTE(review): asyncio's event loop keeps only
# weak references to tasks, which is presumably why they are held here. The
# set is shared by every stream in the process.
task_references: set[asyncio.Task[StreamEvent]] = set()
def _manage_async_event_streaming(
    compiled_graph: CompiledStateGraph,
    config: AgentSearchConfig | None,
    graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]:
    """
    Expose the graph's async ``astream_events`` (v2) stream as a sync iterable.

    A private event loop is created for the stream. Each event is pulled by
    running one ``anext`` step to completion on that loop, so callers can
    consume events from ordinary synchronous code.

    Args:
        compiled_graph: Compiled graph whose ``astream_events`` is consumed.
        config: Forwarded as ``metadata["config"]``. Its ``message_id`` (or
            ``None`` when ``config`` is ``None``) is stringified into
            ``metadata["thread_id"]``.
        graph_input: Input passed to the graph.

    Returns:
        A lazy, single-use iterable of raw stream events. The private loop is
        shut down once the iterable is exhausted, closed early by the consumer,
        or raises. Shutdown cancels leftover tasks, closes the event stream and
        the async generators the loop tracks, and then closes the loop.
    """
    # Local import: only needed to type the upstream generator so that it can
    # be closed explicitly during shutdown.
    from collections.abc import AsyncGenerator

    async def _run_async_event_stream() -> AsyncGenerator[StreamEvent, None]:
        message_id = config.message_id if config else None
        async for event in compiled_graph.astream_events(
            input=graph_input,
            config={"metadata": {"config": config, "thread_id": str(message_id)}},
            # debug=True,
            # indicating v2 here deserves further scrutiny
            version="v2",
        ):
            yield event

    def _shutdown_loop(
        loop: asyncio.AbstractEventLoop,
        async_gen: AsyncGenerator[StreamEvent, None],
    ) -> None:
        """Cancel leftover tasks, close async generators, then close ``loop``."""
        try:
            # Normally nothing is pending here, because every step is run to
            # completion. This covers a step interrupted mid-flight (e.g. by
            # KeyboardInterrupt) and any tasks the graph left behind.
            pending = asyncio.all_tasks(loop)
            for pending_task in pending:
                pending_task.cancel()
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
            # The first __anext__ call on async_gen happens outside the running
            # loop, so the loop's async-generator hooks presumably do not track
            # it. Close it explicitly so the upstream stream can run its own
            # cleanup; this is a no-op if it is already exhausted.
            loop.run_until_complete(async_gen.aclose())
            # Finalize async generators first iterated inside the loop.
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception:
            # Best-effort: a cleanup failure must not mask the stream's own
            # outcome, and the loop must still be closed below.
            logger.exception("Error while shutting down the agent event stream loop")
        finally:
            loop.close()

    def _yield_async_to_sync() -> Iterable[StreamEvent]:
        loop = asyncio.new_event_loop()
        async_gen = _run_async_event_stream()
        try:
            while True:
                # Run the next step of the async stream as a Task on our loop.
                # The shared set holds a strong reference while the task is in
                # flight. The done-callback removes it again, so finished tasks
                # (and the events they hold) are not retained across streams.
                task = asyncio.ensure_future(anext(async_gen), loop=loop)
                task_references.add(task)
                task.add_done_callback(task_references.discard)
                try:
                    event = loop.run_until_complete(task)
                    yield event
                except (StopAsyncIteration, GeneratorExit):
                    # Stream exhausted, or the consumer stopped iterating early.
                    break
        finally:
            _shutdown_loop(loop, async_gen)

    return _yield_async_to_sync()
def run_graph(
    compiled_graph: CompiledStateGraph,
    config: AgentSearchConfig,
    input: BasicInput | MainInput_a,
) -> AnswerStream:
    """
    Execute ``compiled_graph`` on ``input`` and stream the parsed packets.

    This is a generator, so nothing happens until it is first iterated. At
    that point ``config`` is modified in place: initial search decomposition
    and refinement are forced on. Raw events are translated with
    ``_parse_agent_event``, and falsy results (such as ``None``) are dropped.
    """
    # TODO: add these to the environment
    # config.perform_initial_search_path_decision = False
    config.perform_initial_search_decomposition = True
    config.allow_refinement = True

    raw_events = _manage_async_event_streaming(
        compiled_graph=compiled_graph, config=config, graph_input=input
    )
    yield from filter(None, map(_parse_agent_event, raw_events))
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
    """
    Return the process-wide compiled main graph, compiling it on first use.

    NOTE(review): only graph "a" exists today, so every ``graph_name`` maps to
    ``main_graph_builder_a``. The cache also holds a single graph that is not
    keyed by ``graph_name``. Revisit both when another graph is added.
    """
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is not None:
        return _COMPILED_GRAPH

    # Placeholder for per-name dispatch: every name currently builds graph "a".
    builder = main_graph_builder_a
    _COMPILED_GRAPH = builder().compile()
    return _COMPILED_GRAPH
def run_main_graph(
    config: AgentSearchConfig,
    graph_name: str = "a",
) -> AnswerStream:
    """
    Stream the main agent-search graph for ``config.search_request.query``.

    The compiled graph comes from ``load_compiled_graph(graph_name)``. A
    synthetic ``ToolCallKickoff`` named ``"agent_search_0"`` is emitted first,
    followed by everything ``run_graph`` produces. As a generator, none of
    this runs until the stream is first iterated.
    """
    compiled_graph = load_compiled_graph(graph_name)
    question = config.search_request.query

    # Only graph "a" exists so far, so its input type serves every graph_name.
    graph_input = MainInput_a(base_question=question, log_messages=[])

    # Agent search is not a Tool per se, but this is helpful for the frontend
    yield ToolCallKickoff(tool_name="agent_search_0", tool_args={"query": question})
    yield from run_graph(compiled_graph, config, graph_input)
# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
    config: AgentSearchConfig,
) -> AnswerStream:
    """
    Compile the basic graph and return its answer stream for ``config``.

    The graph is built and compiled eagerly on every call (it is not cached).
    The returned stream is produced lazily by ``run_graph``.
    """
    compiled_graph = basic_graph_builder().compile()
    # TODO: unify basic input
    return run_graph(compiled_graph, config, BasicInput())
if __name__ == "__main__":
    # Manual smoke run: build the main agent graph, run it on a sample query
    # with the config from get_test_config(), and log the streamed packets by
    # type. Needs whatever DB/LLM setup get_session_context_manager() and
    # get_default_llms() rely on.
    from onyx.llm.factory import get_default_llms

    for _ in range(1):  # raise the count to repeat the run (e.g. for timing)
        now_start = datetime.now()
        logger.debug(f"Start at {now_start}")

        # Only graph "a" exists so far, so it is built for any GRAPH_VERSION_NAME.
        graph = main_graph_builder_a()
        compiled_graph = graph.compile()
        now_end = datetime.now()
        # Report a plain number of seconds; formatting the raw timedelta would
        # print H:MM:SS.ffffff, which does not match the "seconds" suffix.
        logger.debug(
            f"Graph compiled in {(now_end - now_start).total_seconds():.3f} seconds"
        )

        primary_llm, fast_llm = get_default_llms()
        search_request = SearchRequest(
            # query="what can you do with gitlab?",
            # query="What are the guiding principles behind the development of cockroachDB",
            # query="What are the temperatures in Munich, Hawaii, and New York?",
            # query="When was Washington born?",
            # query="What is Onyx?",
            # query="What is the difference between astronomy and astrology?",
            query="Do a search to tell me what is the difference between astronomy and astrology?",
        )
        # Joachim custom persona
        with get_session_context_manager() as db_session:
            config, _search_tool = get_test_config(
                db_session, primary_llm, fast_llm, search_request
            )
            # search_request.persona = get_persona_by_id(1, None, db_session)
            config.use_persistence = True
            # config.perform_initial_search_path_decision = False
            config.perform_initial_search_decomposition = True

            # Only graph "a" exists so far, so its input type is always used.
            graph_input = MainInput_a(
                base_question=config.search_request.query, log_messages=[]
            )

            # with open("output.txt", "w") as f:
            tool_responses: list = []
            for output in run_graph(compiled_graph, config, graph_input):
                if isinstance(output, ToolCallKickoff):
                    pass
                elif isinstance(output, ExtendedToolResponse):
                    tool_responses.append(output.response)
                    logger.info(
                        f" ---- ET {output.level} - {output.level_question_nr} | "
                    )
                elif isinstance(output, SubQueryPiece):
                    logger.info(
                        f"Sq {output.level} - {output.level_question_nr} - {output.sub_query} | "
                    )
                elif isinstance(output, SubQuestionPiece):
                    logger.info(
                        f"SQ {output.level} - {output.level_question_nr} - {output.sub_question} | "
                    )
                elif (
                    isinstance(output, AgentAnswerPiece)
                    and output.answer_type == "agent_sub_answer"
                ):
                    logger.info(
                        f" ---- SA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                    )
                elif (
                    isinstance(output, AgentAnswerPiece)
                    and output.answer_type == "agent_level_answer"
                ):
                    logger.info(
                        f" ---------- FA {output.level} - {output.level_question_nr} {output.answer_piece} | "
                    )
                elif isinstance(output, RefinedAnswerImprovement):
                    logger.info(
                        f" ---------- RE {output.refined_answer_improvement} | "
                    )

            # for tool_response in tool_responses:
            #     logger.debug(tool_response)