fix for early cancellation test; solves issue with tasks being destroyed while pending

This commit is contained in:
Evan Lohn
2025-01-22 21:15:56 -08:00
parent deea9c8c3c
commit 3ced9bc28b

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
from asyncio import AbstractEventLoop
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from collections.abc import Iterable from collections.abc import Iterable
from datetime import datetime from datetime import datetime
@@ -79,17 +78,9 @@ def _parse_agent_event(
return None return None
async def tear_down(event_loop: AbstractEventLoop) -> None: # https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
# Collect all tasks and cancel those that are not 'done'. # https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
tasks = asyncio.all_tasks(event_loop) task_references: set[asyncio.Task[StreamEvent]] = set()
for task in tasks:
task.cancel()
# Wait for all tasks to complete, ignoring any CancelledErrors
try:
await asyncio.wait(tasks)
except asyncio.exceptions.CancelledError:
pass
def _manage_async_event_streaming( def _manage_async_event_streaming(
@@ -97,40 +88,39 @@ def _manage_async_event_streaming(
config: AgentSearchConfig | None, config: AgentSearchConfig | None,
graph_input: MainInput_a | BasicInput, graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]: ) -> Iterable[StreamEvent]:
async def _run_async_event_stream( async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
loop: AbstractEventLoop, message_id = config.message_id if config else None
) -> AsyncIterable[StreamEvent]: async for event in compiled_graph.astream_events(
try: input=graph_input,
message_id = config.message_id if config else None config={"metadata": {"config": config, "thread_id": str(message_id)}},
async for event in compiled_graph.astream_events( # debug=True,
input=graph_input, # indicating v2 here deserves further scrutiny
config={"metadata": {"config": config, "thread_id": str(message_id)}}, version="v2",
# debug=True, ):
# indicating v2 here deserves further scrutiny yield event
version="v2",
):
yield event
finally:
await tear_down(loop)
# This might be able to be simplified # This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]: def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
try: try:
# Get the async generator # Get the async generator
async_gen = _run_async_event_stream(loop) async_gen = _run_async_event_stream()
# Convert to AsyncIterator # Convert to AsyncIterator
async_iter = async_gen.__aiter__() async_iter = async_gen.__aiter__()
while True: while True:
try: try:
# Create a coroutine by calling anext with the async iterator # Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter) next_coro = anext(async_iter)
task = asyncio.ensure_future(next_coro, loop=loop)
task_references.add(task)
# Run the coroutine to get the next event # Run the coroutine to get the next event
event = loop.run_until_complete(next_coro) event = loop.run_until_complete(task)
yield event yield event
except StopAsyncIteration: except (StopAsyncIteration, GeneratorExit):
break break
finally: finally:
for task in task_references.pop():
task.cancel()
loop.close() loop.close()
return _yield_async_to_sync() return _yield_async_to_sync()