fix for early cancellation test; solves issue with tasks being destroyed while pending

This commit is contained in:
Evan Lohn
2025-01-22 21:15:56 -08:00
parent deea9c8c3c
commit 3ced9bc28b

View File

@@ -1,5 +1,4 @@
import asyncio
from asyncio import AbstractEventLoop
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
@@ -79,17 +78,9 @@ def _parse_agent_event(
return None
async def tear_down(event_loop: AbstractEventLoop) -> None:
# Collect all tasks and cancel those that are not 'done'.
tasks = asyncio.all_tasks(event_loop)
for task in tasks:
task.cancel()
# Wait for all tasks to complete, ignoring any CancelledErrors
try:
await asyncio.wait(tasks)
except asyncio.exceptions.CancelledError:
pass
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
task_references: set[asyncio.Task[StreamEvent]] = set()
def _manage_async_event_streaming(
@@ -97,40 +88,39 @@ def _manage_async_event_streaming(
config: AgentSearchConfig | None,
graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream(
loop: AbstractEventLoop,
) -> AsyncIterable[StreamEvent]:
try:
message_id = config.message_id if config else None
async for event in compiled_graph.astream_events(
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event
finally:
await tear_down(loop)
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
message_id = config.message_id if config else None
async for event in compiled_graph.astream_events(
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event
# This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _run_async_event_stream(loop)
async_gen = _run_async_event_stream()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
task = asyncio.ensure_future(next_coro, loop=loop)
task_references.add(task)
# Run the coroutine to get the next event
event = loop.run_until_complete(next_coro)
event = loop.run_until_complete(task)
yield event
except StopAsyncIteration:
except (StopAsyncIteration, GeneratorExit):
break
finally:
for task in task_references.pop():
task.cancel()
loop.close()
return _yield_async_to_sync()