mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 20:24:32 +02:00
fix for early cancellation test; solves issue with tasks being destroyed while pending
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
@@ -79,17 +78,9 @@ def _parse_agent_event(
|
||||
return None
|
||||
|
||||
|
||||
async def tear_down(event_loop: AbstractEventLoop) -> None:
|
||||
# Collect all tasks and cancel those that are not 'done'.
|
||||
tasks = asyncio.all_tasks(event_loop)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for all tasks to complete, ignoring any CancelledErrors
|
||||
try:
|
||||
await asyncio.wait(tasks)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
pass
|
||||
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
|
||||
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
|
||||
task_references: set[asyncio.Task[StreamEvent]] = set()
|
||||
|
||||
|
||||
def _manage_async_event_streaming(
|
||||
@@ -97,40 +88,39 @@ def _manage_async_event_streaming(
|
||||
config: AgentSearchConfig | None,
|
||||
graph_input: MainInput_a | BasicInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
async def _run_async_event_stream(
|
||||
loop: AbstractEventLoop,
|
||||
) -> AsyncIterable[StreamEvent]:
|
||||
try:
|
||||
message_id = config.message_id if config else None
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
finally:
|
||||
await tear_down(loop)
|
||||
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
||||
message_id = config.message_id if config else None
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
|
||||
# This might be able to be simplified
|
||||
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _run_async_event_stream(loop)
|
||||
async_gen = _run_async_event_stream()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
task = asyncio.ensure_future(next_coro, loop=loop)
|
||||
task_references.add(task)
|
||||
# Run the coroutine to get the next event
|
||||
event = loop.run_until_complete(next_coro)
|
||||
event = loop.run_until_complete(task)
|
||||
yield event
|
||||
except StopAsyncIteration:
|
||||
except (StopAsyncIteration, GeneratorExit):
|
||||
break
|
||||
finally:
|
||||
for task in task_references.pop():
|
||||
task.cancel()
|
||||
loop.close()
|
||||
|
||||
return _yield_async_to_sync()
|
||||
|
Reference in New Issue
Block a user