diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py index bc8c1cd1fb71..c21c1baac620 100644 --- a/backend/onyx/agents/agent_search/run_graph.py +++ b/backend/onyx/agents/agent_search/run_graph.py @@ -1,5 +1,4 @@ import asyncio -from asyncio import AbstractEventLoop from collections.abc import AsyncIterable from collections.abc import Iterable from datetime import datetime @@ -79,17 +78,9 @@ def _parse_agent_event( return None -async def tear_down(event_loop: AbstractEventLoop) -> None: - # Collect all tasks and cancel those that are not 'done'. - tasks = asyncio.all_tasks(event_loop) - for task in tasks: - task.cancel() - - # Wait for all tasks to complete, ignoring any CancelledErrors - try: - await asyncio.wait(tasks) - except asyncio.exceptions.CancelledError: - pass +# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator +# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks +task_references: set[asyncio.Task[StreamEvent]] = set() def _manage_async_event_streaming( @@ -97,40 +88,39 @@ def _manage_async_event_streaming( config: AgentSearchConfig | None, graph_input: MainInput_a | BasicInput, ) -> Iterable[StreamEvent]: - async def _run_async_event_stream( - loop: AbstractEventLoop, - ) -> AsyncIterable[StreamEvent]: - try: - message_id = config.message_id if config else None - async for event in compiled_graph.astream_events( - input=graph_input, - config={"metadata": {"config": config, "thread_id": str(message_id)}}, - # debug=True, - # indicating v2 here deserves further scrutiny - version="v2", - ): - yield event - finally: - await tear_down(loop) + async def _run_async_event_stream() -> AsyncIterable[StreamEvent]: + message_id = config.message_id if config else None + async for event in compiled_graph.astream_events( + input=graph_input, + config={"metadata": {"config": config, "thread_id": str(message_id)}}, + # debug=True, + # indicating v2 here deserves further scrutiny + version="v2", + ): + yield event # This might be able to be simplified def _yield_async_to_sync() -> Iterable[StreamEvent]: loop = asyncio.new_event_loop() try: # Get the async generator - async_gen = _run_async_event_stream(loop) + async_gen = _run_async_event_stream() # Convert to AsyncIterator async_iter = async_gen.__aiter__() while True: try: # Create a coroutine by calling anext with the async iterator next_coro = anext(async_iter) + task = asyncio.ensure_future(next_coro, loop=loop) + task_references.add(task) # Run the coroutine to get the next event - event = loop.run_until_complete(next_coro) + event = loop.run_until_complete(task) yield event - except StopAsyncIteration: + except (StopAsyncIteration, GeneratorExit): break finally: + for task in task_references.pop(): + task.cancel() loop.close() return _yield_async_to_sync()