From 1fd4d9d514560c4cbf21b9e82b09a65426c48ad1 Mon Sep 17 00:00:00 2001 From: jackstar12 <62219658+jackstar12@users.noreply.github.com> Date: Fri, 18 Aug 2023 11:25:33 +0200 Subject: [PATCH] cancel all long-running tasks (#1793) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add centralized task management in order to properly cleanup all long-running tasks we have to keep a list of them * use new task management functions * unify shutdown events * vlads suggestions rename variable for create_task wrap cancel() with try/catch fixup * rename func to coro --------- Co-authored-by: dni ⚡ --- lnbits/app.py | 26 +++++++++-------------- lnbits/core/tasks.py | 50 +++++++++++++------------------------------- lnbits/tasks.py | 22 ++++++++++++++++++- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/lnbits/app.py b/lnbits/app.py index 9bc62e496..4f5d7a005 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -28,9 +28,9 @@ from lnbits.core.services import websocketUpdater from lnbits.core.tasks import ( # register_watchdog,; unregister_watchdog, register_killswitch, register_task_listeners, - unregister_killswitch, ) from lnbits.settings import settings +from lnbits.tasks import cancel_all_tasks, create_permanent_task from lnbits.wallets import get_wallet_class, set_wallet_class from .commands import db_versions, load_disabled_extension_list, migrate_databases @@ -52,7 +52,6 @@ from .middleware import ( ) from .requestvars import g from .tasks import ( - catch_everything_and_restart, check_pending_payments, internal_invoice_listener, invoice_listener, @@ -366,6 +365,9 @@ def register_startup(app: FastAPI): def register_shutdown(app: FastAPI): @app.on_event("shutdown") async def on_shutdown(): + cancel_all_tasks() + # wait a bit to allow them to finish, so that cleanup can run without problems + await asyncio.sleep(0.1) WALLET = get_wallet_class() await WALLET.cleanup() @@ -380,7 +382,7 @@ def initialize_server_logger(): msg = await serverlog_queue.get() await websocketUpdater(super_user_hash, msg) - asyncio.create_task(update_websocket_serverlog()) + create_permanent_task(update_websocket_serverlog) logger.add( lambda msg: serverlog_queue.put_nowait(msg), @@ -421,21 +423,13 @@ def register_async_tasks(app): @app.on_event("startup") async def listeners(): - loop = asyncio.get_event_loop() - loop.create_task(catch_everything_and_restart(check_pending_payments)) - loop.create_task(catch_everything_and_restart(invoice_listener)) - loop.create_task(catch_everything_and_restart(internal_invoice_listener)) - await register_task_listeners() - # await register_watchdog() - await register_killswitch() + create_permanent_task(check_pending_payments) + create_permanent_task(invoice_listener) + create_permanent_task(internal_invoice_listener) + register_task_listeners() + register_killswitch() # await run_deferred_async() # calle: doesn't do anyting? - @app.on_event("shutdown") - async def stop_listeners(): - # await unregister_watchdog() - await unregister_killswitch() - pass - def register_exception_handlers(app: FastAPI): @app.exception_handler(Exception) diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index 4572f73e8..5f7f67bb6 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -1,11 +1,16 @@ import asyncio -from typing import Dict, Optional +from typing import Dict import httpx from loguru import logger from lnbits.settings import get_wallet_class, settings -from lnbits.tasks import SseListenersDict, register_invoice_listener +from lnbits.tasks import ( + SseListenersDict, + create_permanent_task, + create_task, + register_invoice_listener, +) from . import db from .crud import get_balance_notify, get_wallet @@ -16,28 +21,14 @@ api_invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict( "api_invoice_listeners" ) -killswitch: Optional[asyncio.Task] = None -watchdog: Optional[asyncio.Task] = None - -async def register_killswitch(): +def register_killswitch(): """ - Registers a killswitch which will check lnbits-status repository - for a signal from LNbits and will switch to VoidWallet if the killswitch is triggered. + Registers a killswitch which will check lnbits-status repository for a signal from + LNbits and will switch to VoidWallet if the killswitch is triggered. """ logger.debug("Starting killswitch task") - global killswitch - killswitch = asyncio.create_task(killswitch_task()) - - -async def unregister_killswitch(): - """ - Unregisters a killswitch taskl - """ - global killswitch - if killswitch: - logger.debug("Stopping killswitch task") - killswitch.cancel() + create_permanent_task(killswitch_task) async def killswitch_task(): @@ -67,20 +58,9 @@ async def register_watchdog(): Registers a watchdog which will check lnbits balance and nodebalance and will switch to VoidWallet if the watchdog delta is reached. """ - # TODO: implement watchdog porperly + # TODO: implement watchdog properly # logger.debug("Starting watchdog task") - # global watchdog - # watchdog = asyncio.create_task(watchdog_task()) - - -async def unregister_watchdog(): - """ - Unregisters a watchdog task - """ - global watchdog - if watchdog: - logger.debug("Stopping watchdog task") - watchdog.cancel() + # create_permanent_task(watchdog_task) async def watchdog_task(): @@ -98,7 +78,7 @@ async def watchdog_task(): await asyncio.sleep(settings.lnbits_watchdog_interval * 60) -async def register_task_listeners(): +def register_task_listeners(): """ Registers an invoice listener queue for the core tasks. Incoming payaments in this queue will eventually trigger the signals sent to all other extensions @@ -108,7 +88,7 @@ async def register_task_listeners(): # we register invoice_paid_queue to receive all incoming invoices register_invoice_listener(invoice_paid_queue, "core/tasks.py") # register a worker that will react to invoices - asyncio.create_task(wait_for_paid_invoices(invoice_paid_queue)) + create_task(wait_for_paid_invoices(invoice_paid_queue)) async def wait_for_paid_invoices(invoice_paid_queue: asyncio.Queue): diff --git a/lnbits/tasks.py b/lnbits/tasks.py index c6e97e709..506880ee0 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -3,7 +3,7 @@ import time import traceback import uuid from http import HTTPStatus -from typing import Dict, Optional +from typing import Dict, List, Optional from fastapi.exceptions import HTTPException from loguru import logger @@ -19,6 +19,26 @@ from lnbits.wallets import get_wallet_class from .core import db +tasks: List[asyncio.Task] = [] + + +def create_task(coro): + task = asyncio.create_task(coro) + tasks.append(task) + return task + + +def create_permanent_task(func): + return create_task(catch_everything_and_restart(func)) + + +def cancel_all_tasks(): + for task in tasks: + try: + task.cancel() + except Exception as exc: + logger.warning(f"error while cancelling task: {str(exc)}") + async def catch_everything_and_restart(func): try: