diff --git a/lnbits/app.py b/lnbits/app.py index 1ed2ce29f..112739d78 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -12,7 +12,7 @@ from .core import core_app from .db import open_db, open_ext_db from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored from .proxy_fix import ASGIProxyFix -from .tasks import invoice_listener, webhook_handler, grab_app_for_later +from .tasks import run_deferred_async, invoice_listener, webhook_handler, grab_app_for_later secure_headers = SecureHeaders(hsts=False) @@ -111,6 +111,8 @@ def register_async_tasks(app): @app.before_serving async def listeners(): + run_deferred_async(app.nursery) + app.nursery.start_soon(invoice_listener) print("started global invoice_listener.") diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py index 7ccaded9e..e863d0af2 100644 --- a/lnbits/core/__init__.py +++ b/lnbits/core/__init__.py @@ -8,8 +8,8 @@ core_app: Blueprint = Blueprint( from .views.api import * # noqa from .views.generic import * # noqa -from .tasks import on_invoice_paid +from .tasks import register_listeners -from lnbits.tasks import register_invoice_listener +from lnbits.tasks import record_async -register_invoice_listener("core", on_invoice_paid) +core_app.record(record_async(register_listeners)) diff --git a/lnbits/core/models.py b/lnbits/core/models.py index 243f9342a..c65bdf93c 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -70,6 +70,7 @@ class Payment(NamedTuple): preimage: str payment_hash: str extra: Dict + wallet_id: str @classmethod def from_row(cls, row: Row): @@ -84,6 +85,7 @@ class Payment(NamedTuple): fee=row["fee"], memo=row["memo"], time=row["time"], + wallet_id=row["wallet"], ) @property diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index 577682370..e0e28391f 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -1,15 +1,22 @@ import trio # type: ignore from typing import List -from .models import Payment +from lnbits.tasks import register_invoice_listener sse_listeners: List[trio.MemorySendChannel] = [] -async def on_invoice_paid(payment: Payment): - for send_channel in sse_listeners: - try: - send_channel.send_nowait(payment) - except trio.WouldBlock: - print("removing sse listener", send_channel) - sse_listeners.remove(send_channel) +async def register_listeners(): + invoice_paid_chan_send, invoice_paid_chan_recv = trio.open_memory_channel(5) + register_invoice_listener(invoice_paid_chan_send) + await wait_for_paid_invoices(invoice_paid_chan_recv) + + +async def wait_for_paid_invoices(invoice_paid_chan: trio.MemoryReceiveChannel): + async for payment in invoice_paid_chan: + for send_channel in sse_listeners: + try: + send_channel.send_nowait(payment) + except trio.WouldBlock: + print("removing sse listener", send_channel) + sse_listeners.remove(send_channel) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 1aecbc490..39d1f1072 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -128,6 +128,7 @@ async def api_payment(payment_hash): @api_check_wallet_key("invoice") async def api_payments_sse(): g.db.close() + this_wallet_id = g.wallet.id send_payment, receive_payment = trio.open_memory_channel(0) @@ -138,7 +139,8 @@ async def api_payments_sse(): async def payment_received() -> None: async for payment in receive_payment: - await send_event.send(("payment", payment)) + if payment.wallet_id == this_wallet_id: + await send_event.send(("payment", payment)) async def repeat_keepalive(): await trio.sleep(1) @@ -160,7 +162,6 @@ async def api_payments_sse(): yield b"\n".join(message) + b"\r\n\r\n" except trio.Cancelled: - print("canceled!") return response = await make_response( diff --git a/lnbits/extensions/lnurlp/__init__.py b/lnbits/extensions/lnurlp/__init__.py index a4bc65f37..2c9a3e023 100644 --- a/lnbits/extensions/lnurlp/__init__.py +++ b/lnbits/extensions/lnurlp/__init__.py @@ -7,8 +7,8 @@ lnurlp_ext: Blueprint = Blueprint("lnurlp", __name__, static_folder="static", te from .views_api import * # noqa from .views import * # noqa from .lnurl import * # noqa -from .tasks import on_invoice_paid +from .tasks import register_listeners -from lnbits.tasks import register_invoice_listener +from lnbits.tasks import record_async -register_invoice_listener("lnurlp", on_invoice_paid) +lnurlp_ext.record(record_async(register_listeners)) diff --git a/lnbits/extensions/lnurlp/tasks.py b/lnbits/extensions/lnurlp/tasks.py index 37d245597..00e6da97e 100644 --- a/lnbits/extensions/lnurlp/tasks.py +++ b/lnbits/extensions/lnurlp/tasks.py @@ -1,10 +1,23 @@ +import trio # type: ignore import httpx from lnbits.core.models import Payment +from lnbits.tasks import run_on_pseudo_request, register_invoice_listener from .crud import get_pay_link_by_invoice, mark_webhook_sent +async def register_listeners(): + invoice_paid_chan_send, invoice_paid_chan_recv = trio.open_memory_channel(2) + register_invoice_listener(invoice_paid_chan_send) + await wait_for_paid_invoices(invoice_paid_chan_recv) + + +async def wait_for_paid_invoices(invoice_paid_chan: trio.MemoryReceiveChannel): + async for payment in invoice_paid_chan: + await run_on_pseudo_request(on_invoice_paid, payment) + + async def on_invoice_paid(payment: Payment) -> None: islnurlp = "lnurlp" == payment.extra.get("tag") if islnurlp: diff --git a/lnbits/tasks.py b/lnbits/tasks.py index 7a5bf7184..74b32e905 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -1,14 +1,13 @@ import trio # type: ignore from http import HTTPStatus -from typing import Optional, Tuple, List, Callable, Awaitable +from typing import Optional, List, Callable from quart import Request, g from quart_trio import QuartTrio from werkzeug.datastructures import Headers -from lnbits.db import open_db, open_ext_db +from lnbits.db import open_db from lnbits.settings import WALLET -from lnbits.core.models import Payment from lnbits.core.crud import get_standalone_payment main_app: Optional[QuartTrio] = None @@ -19,6 +18,21 @@ def grab_app_for_later(app: QuartTrio): main_app = app +deferred_async: List[Callable] = [] + + +def record_async(func: Callable) -> Callable: + def recorder(state): + deferred_async.append(func) + + return recorder + + +def run_deferred_async(nursery): + for func in deferred_async: + nursery.start_soon(func) + + async def send_push_promise(a, b) -> None: pass @@ -45,16 +59,16 @@ async def run_on_pseudo_request(func: Callable, *args): nursery.start_soon(run) -invoice_listeners: List[Tuple[str, Callable[[Payment], Awaitable[None]]]] = [] +invoice_listeners: List[trio.MemorySendChannel] = [] -def register_invoice_listener(ext_name: str, cb: Callable[[Payment], Awaitable[None]]): +def register_invoice_listener(send_chan: trio.MemorySendChannel): """ A method intended for extensions to call when they want to be notified about new invoice payments incoming. """ - print(f"registering {ext_name} invoice_listener callback: {cb}") - invoice_listeners.append((ext_name, cb)) + print(f"registering invoice_listener: {send_chan}") + invoice_listeners.append(send_chan) async def webhook_handler(): @@ -73,9 +87,5 @@ async def invoice_callback_dispatcher(checking_id: str): payment = get_standalone_payment(checking_id) if payment and payment.is_in: payment.set_pending(False) - for ext_name, cb in invoice_listeners: - if ext_name == "core": - await cb(payment) - else: - with open_ext_db(ext_name) as g.ext_db: # type: ignore - await cb(payment) + for send_chan in invoice_listeners: + await send_chan.send(payment)