From b66a8b3de91cd6018c03e79943d3e74acd484f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Mon, 16 Dec 2024 09:45:39 +0100 Subject: [PATCH] fix: make startup extension check sync (#2819) --- .github/workflows/jmeter.yml | 2 +- lnbits/app.py | 27 +++++++++++------- lnbits/core/services/extensions.py | 46 +++++++++++++++++++++++++----- 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/.github/workflows/jmeter.yml b/.github/workflows/jmeter.yml index cb53db8a0..c3f2a5d38 100644 --- a/.github/workflows/jmeter.yml +++ b/.github/workflows/jmeter.yml @@ -31,7 +31,7 @@ jobs: LNBITS_BACKEND_WALLET_CLASS: FakeWallet run: | poetry run lnbits & - sleep 5 + sleep 10 - name: setup java version run: | diff --git a/lnbits/app.py b/lnbits/app.py index d3aec7582..8b2091175 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -65,7 +65,6 @@ from .middleware import ( from .requestvars import g from .tasks import ( check_pending_payments, - create_task, internal_invoice_listener, invoice_listener, ) @@ -81,6 +80,10 @@ async def startup(app: FastAPI): await check_admin_settings() await check_webpush_settings() + # check extensions after restart + if not settings.lnbits_extensions_deactivate_all: + await check_and_register_extensions(app) + log_server_info() # initialize WALLET @@ -97,7 +100,7 @@ async def startup(app: FastAPI): init_core_routers(app) # initialize tasks - register_async_tasks(app) + register_async_tasks() async def shutdown(): @@ -395,16 +398,21 @@ def register_new_ratelimiter(app: FastAPI) -> Callable: return register_new_ratelimiter_fn +def register_ext_tasks(ext: Extension) -> None: + """Register extension async tasks.""" + ext_module = importlib.import_module(ext.module_name) + + if hasattr(ext_module, f"{ext.code}_start"): + ext_start_func = getattr(ext_module, f"{ext.code}_start") + ext_start_func() + + def register_ext_routes(app: FastAPI, ext: Extension) -> None: """Register FastAPI routes for extension.""" ext_module = importlib.import_module(ext.module_name) ext_route = getattr(ext_module, f"{ext.code}_ext") - if hasattr(ext_module, f"{ext.code}_start"): - ext_start_func = getattr(ext_module, f"{ext.code}_start") - ext_start_func() - if hasattr(ext_module, f"{ext.code}_static_files"): ext_statics = getattr(ext_module, f"{ext.code}_static_files") for s in ext_statics: @@ -431,15 +439,12 @@ async def check_and_register_extensions(app: FastAPI): for ext in await get_valid_extensions(False): try: register_ext_routes(app, ext) + register_ext_tasks(ext) except Exception as exc: logger.error(f"Could not load extension `{ext.code}`: {exc!s}") -def register_async_tasks(app: FastAPI): - - # check extensions after restart - if not settings.lnbits_extensions_deactivate_all: - create_task(check_and_register_extensions(app)) +def register_async_tasks(): create_permanent_task(wait_for_audit_data) create_permanent_task(check_pending_payments) diff --git a/lnbits/core/services/extensions.py b/lnbits/core/services/extensions.py index 79cfd3603..a30c0d46e 100644 --- a/lnbits/core/services/extensions.py +++ b/lnbits/core/services/extensions.py @@ -49,6 +49,8 @@ async def install_extension(ext_info: InstallableExtension) -> Extension: # call stop while the old routes are still active await stop_extension_background_work(ext_id) + await start_extension_background_work(ext_id) + return extension @@ -76,7 +78,7 @@ async def deactivate_extension(ext_id: str): async def stop_extension_background_work(ext_id: str) -> bool: """ Stop background work for extension (like asyncio.Tasks, WebSockets, etc). - Extensions SHOULD expose a `api_stop()` function. + Extension must expose a `myextension_stop()` function if it is starting tasks. """ upgrade_hash = settings.extension_upgrade_hash(ext_id) ext = Extension(code=ext_id, is_valid=True, upgrade_hash=upgrade_hash) @@ -85,11 +87,10 @@ async def stop_extension_background_work(ext_id: str) -> bool: logger.info(f"Stopping background work for extension '{ext.module_name}'.") old_module = importlib.import_module(ext.module_name) - # Extensions must expose an `{ext_id}_stop()` function at the module level - # The `api_stop()` function is for backwards compatibility (will be deprecated) - stop_fns = [f"{ext_id}_stop", "api_stop"] - stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None) - assert stop_fn_name, f"No stop function found for '{ext.module_name}'." + stop_fn_name = f"{ext_id}_stop" + assert hasattr( + old_module, stop_fn_name + ), f"No stop function found for '{ext.module_name}'." stop_fn = getattr(old_module, stop_fn_name) if stop_fn: @@ -97,7 +98,6 @@ async def stop_extension_background_work(ext_id: str) -> bool: await stop_fn() else: stop_fn() - logger.info(f"Stopped background work for extension '{ext.module_name}'.") except Exception as ex: logger.warning(f"Failed to stop background work for '{ext.module_name}'.") @@ -107,6 +107,38 @@ async def stop_extension_background_work(ext_id: str) -> bool: return True +async def start_extension_background_work(ext_id: str) -> bool: + """ + Start background work for extension (like asyncio.Tasks, WebSockets, etc). + Extension CAN expose a `myextension_start()` function if it is starting tasks. + Extension MUST expose a `myextension_stop()` in that case. + """ + upgrade_hash = settings.extension_upgrade_hash(ext_id) + ext = Extension(code=ext_id, is_valid=True, upgrade_hash=upgrade_hash) + + try: + logger.info(f"Starting background work for extension '{ext.module_name}'.") + new_module = importlib.import_module(ext.module_name) + start_fn_name = f"{ext_id}_start" + + # start function is optional, return False if not found + if not hasattr(new_module, start_fn_name): + return False + + start_fn = getattr(new_module, start_fn_name) + if start_fn: + if asyncio.iscoroutinefunction(start_fn): + await start_fn() + else: + start_fn() + logger.info(f"Started background work for extension '{ext.module_name}'.") + return True + except Exception as ex: + logger.warning(f"Failed to start background work for '{ext.module_name}'.") + logger.warning(ex) + return False + + async def get_valid_extensions( include_deactivated: Optional[bool] = True, ) -> list[Extension]: