cancel all long-running tasks (#1793)

* add centralized task management

in order to properly cleanup all long-running tasks we have to keep a list of them

* use new task management functions

* unify shutdown events

* vlads suggestions

rename variable for create_task
wrap cancel() with try/catch

fixup

* rename func to coro

---------

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
jackstar12
2023-08-18 11:25:33 +02:00
committed by GitHub
parent 65db43ace4
commit 1fd4d9d514
3 changed files with 46 additions and 52 deletions

View File

@@ -28,9 +28,9 @@ from lnbits.core.services import websocketUpdater
from lnbits.core.tasks import ( # register_watchdog,; unregister_watchdog, from lnbits.core.tasks import ( # register_watchdog,; unregister_watchdog,
register_killswitch, register_killswitch,
register_task_listeners, register_task_listeners,
unregister_killswitch,
) )
from lnbits.settings import settings from lnbits.settings import settings
from lnbits.tasks import cancel_all_tasks, create_permanent_task
from lnbits.wallets import get_wallet_class, set_wallet_class from lnbits.wallets import get_wallet_class, set_wallet_class
from .commands import db_versions, load_disabled_extension_list, migrate_databases from .commands import db_versions, load_disabled_extension_list, migrate_databases
@@ -52,7 +52,6 @@ from .middleware import (
) )
from .requestvars import g from .requestvars import g
from .tasks import ( from .tasks import (
catch_everything_and_restart,
check_pending_payments, check_pending_payments,
internal_invoice_listener, internal_invoice_listener,
invoice_listener, invoice_listener,
@@ -366,6 +365,9 @@ def register_startup(app: FastAPI):
def register_shutdown(app: FastAPI): def register_shutdown(app: FastAPI):
@app.on_event("shutdown") @app.on_event("shutdown")
async def on_shutdown(): async def on_shutdown():
cancel_all_tasks()
# wait a bit to allow them to finish, so that cleanup can run without problems
await asyncio.sleep(0.1)
WALLET = get_wallet_class() WALLET = get_wallet_class()
await WALLET.cleanup() await WALLET.cleanup()
@@ -380,7 +382,7 @@ def initialize_server_logger():
msg = await serverlog_queue.get() msg = await serverlog_queue.get()
await websocketUpdater(super_user_hash, msg) await websocketUpdater(super_user_hash, msg)
asyncio.create_task(update_websocket_serverlog()) create_permanent_task(update_websocket_serverlog)
logger.add( logger.add(
lambda msg: serverlog_queue.put_nowait(msg), lambda msg: serverlog_queue.put_nowait(msg),
@@ -421,21 +423,13 @@ def register_async_tasks(app):
@app.on_event("startup") @app.on_event("startup")
async def listeners(): async def listeners():
loop = asyncio.get_event_loop() create_permanent_task(check_pending_payments)
loop.create_task(catch_everything_and_restart(check_pending_payments)) create_permanent_task(invoice_listener)
loop.create_task(catch_everything_and_restart(invoice_listener)) create_permanent_task(internal_invoice_listener)
loop.create_task(catch_everything_and_restart(internal_invoice_listener)) register_task_listeners()
await register_task_listeners() register_killswitch()
# await register_watchdog()
await register_killswitch()
# await run_deferred_async() # calle: doesn't do anyting? # await run_deferred_async() # calle: doesn't do anyting?
@app.on_event("shutdown")
async def stop_listeners():
# await unregister_watchdog()
await unregister_killswitch()
pass
def register_exception_handlers(app: FastAPI): def register_exception_handlers(app: FastAPI):
@app.exception_handler(Exception) @app.exception_handler(Exception)

View File

@@ -1,11 +1,16 @@
import asyncio import asyncio
from typing import Dict, Optional from typing import Dict
import httpx import httpx
from loguru import logger from loguru import logger
from lnbits.settings import get_wallet_class, settings from lnbits.settings import get_wallet_class, settings
from lnbits.tasks import SseListenersDict, register_invoice_listener from lnbits.tasks import (
SseListenersDict,
create_permanent_task,
create_task,
register_invoice_listener,
)
from . import db from . import db
from .crud import get_balance_notify, get_wallet from .crud import get_balance_notify, get_wallet
@@ -16,28 +21,14 @@ api_invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict(
"api_invoice_listeners" "api_invoice_listeners"
) )
killswitch: Optional[asyncio.Task] = None
watchdog: Optional[asyncio.Task] = None
def register_killswitch():
async def register_killswitch():
""" """
Registers a killswitch which will check lnbits-status repository Registers a killswitch which will check lnbits-status repository for a signal from
for a signal from LNbits and will switch to VoidWallet if the killswitch is triggered. LNbits and will switch to VoidWallet if the killswitch is triggered.
""" """
logger.debug("Starting killswitch task") logger.debug("Starting killswitch task")
global killswitch create_permanent_task(killswitch_task)
killswitch = asyncio.create_task(killswitch_task())
async def unregister_killswitch():
"""
Unregisters a killswitch taskl
"""
global killswitch
if killswitch:
logger.debug("Stopping killswitch task")
killswitch.cancel()
async def killswitch_task(): async def killswitch_task():
@@ -67,20 +58,9 @@ async def register_watchdog():
Registers a watchdog which will check lnbits balance and nodebalance Registers a watchdog which will check lnbits balance and nodebalance
and will switch to VoidWallet if the watchdog delta is reached. and will switch to VoidWallet if the watchdog delta is reached.
""" """
# TODO: implement watchdog porperly # TODO: implement watchdog properly
# logger.debug("Starting watchdog task") # logger.debug("Starting watchdog task")
# global watchdog # create_permanent_task(watchdog_task)
# watchdog = asyncio.create_task(watchdog_task())
async def unregister_watchdog():
"""
Unregisters a watchdog task
"""
global watchdog
if watchdog:
logger.debug("Stopping watchdog task")
watchdog.cancel()
async def watchdog_task(): async def watchdog_task():
@@ -98,7 +78,7 @@ async def watchdog_task():
await asyncio.sleep(settings.lnbits_watchdog_interval * 60) await asyncio.sleep(settings.lnbits_watchdog_interval * 60)
async def register_task_listeners(): def register_task_listeners():
""" """
Registers an invoice listener queue for the core tasks. Registers an invoice listener queue for the core tasks.
Incoming payaments in this queue will eventually trigger the signals sent to all other extensions Incoming payaments in this queue will eventually trigger the signals sent to all other extensions
@@ -108,7 +88,7 @@ async def register_task_listeners():
# we register invoice_paid_queue to receive all incoming invoices # we register invoice_paid_queue to receive all incoming invoices
register_invoice_listener(invoice_paid_queue, "core/tasks.py") register_invoice_listener(invoice_paid_queue, "core/tasks.py")
# register a worker that will react to invoices # register a worker that will react to invoices
asyncio.create_task(wait_for_paid_invoices(invoice_paid_queue)) create_task(wait_for_paid_invoices(invoice_paid_queue))
async def wait_for_paid_invoices(invoice_paid_queue: asyncio.Queue): async def wait_for_paid_invoices(invoice_paid_queue: asyncio.Queue):

View File

@@ -3,7 +3,7 @@ import time
import traceback import traceback
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, Optional from typing import Dict, List, Optional
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from loguru import logger from loguru import logger
@@ -19,6 +19,26 @@ from lnbits.wallets import get_wallet_class
from .core import db from .core import db
tasks: List[asyncio.Task] = []
def create_task(coro):
task = asyncio.create_task(coro)
tasks.append(task)
return task
def create_permanent_task(func):
return create_task(catch_everything_and_restart(func))
def cancel_all_tasks():
for task in tasks:
try:
task.cancel()
except Exception as exc:
logger.warning(f"error while cancelling task: {str(exc)}")
async def catch_everything_and_restart(func): async def catch_everything_and_restart(func):
try: try: