diff --git a/lnbits/app.py b/lnbits/app.py index 4f5d7a005..69a184214 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -14,7 +14,6 @@ from typing import Callable, List from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.middleware.gzip import GZipMiddleware from fastapi.staticfiles import StaticFiles from loguru import logger from slowapi import Limiter @@ -45,6 +44,7 @@ from .core.views.generic import core_html_routes from .extension_manager import Extension, InstallableExtension, get_valid_extensions from .helpers import template_renderer from .middleware import ( + CustomGZipMiddleware, ExtensionsRedirectMiddleware, InstalledExtensionMiddleware, add_ip_block_middleware, @@ -87,7 +87,9 @@ def create_app() -> FastAPI: CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"] ) - app.add_middleware(GZipMiddleware, minimum_size=1000) + app.add_middleware( + CustomGZipMiddleware, minimum_size=1000, exclude_paths=["/api/v1/payments/sse"] + ) # order of these two middlewares is important app.add_middleware(InstalledExtensionMiddleware) diff --git a/lnbits/core/models.py b/lnbits/core/models.py index 01e485c9e..c9fb56a36 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -170,6 +170,8 @@ class Payment(FromRowModel): async def set_pending(self, pending: bool) -> None: from .crud import update_payment_status + self.pending = pending + await update_payment_status(self.checking_id, pending) async def check_status( diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index c1a1ad2aa..958d2b164 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -4,10 +4,9 @@ import json import uuid from http import HTTPStatus from io import BytesIO -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse -import async_timeout import httpx import pyqrcode from fastapi import ( @@ -421,34 +420,18 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet): logger.debug(f"adding sse listener for wallet: {uid}") api_invoice_listeners[uid] = payment_queue - send_queue: asyncio.Queue[Tuple[str, Payment]] = asyncio.Queue(0) - - async def payment_received() -> None: - while True: - try: - async with async_timeout.timeout(1): - payment: Payment = await payment_queue.get() - if payment.wallet_id == this_wallet_id: - logger.debug("sse listener: payment received", payment) - await send_queue.put(("payment-received", payment)) - except asyncio.TimeoutError: - pass - - task = asyncio.create_task(payment_received()) - try: while True: if await request.is_disconnected(): await request.close() break - typ, data = await send_queue.get() - if data: - jdata = json.dumps(dict(data.dict(), pending=False)) - yield dict(data=jdata, event=typ) + payment: Payment = await payment_queue.get() + if payment.wallet_id == this_wallet_id: + logger.debug("sse listener: payment received", payment) + yield dict(data=payment.json(), event="payment-received") except asyncio.CancelledError: logger.debug(f"removing listener for wallet {uid}") api_invoice_listeners.pop(uid) - task.cancel() return diff --git a/lnbits/middleware.py b/lnbits/middleware.py index d6d4206d2..2944702e4 100644 --- a/lnbits/middleware.py +++ b/lnbits/middleware.py @@ -7,6 +7,7 @@ from fastapi.responses import HTMLResponse, JSONResponse from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from slowapi.middleware import SlowAPIMiddleware +from starlette.middleware.gzip import GZipMiddleware from starlette.types import ASGIApp, Receive, Scope, Send from lnbits.core import core_app_extra @@ -115,6 +116,18 @@ class InstalledExtensionMiddleware: ) +class CustomGZipMiddleware(GZipMiddleware): + def __init__(self, *args, exclude_paths=None, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_paths = exclude_paths or [] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if "path" in scope and scope["path"] in self.exclude_paths: + await self.app(scope, receive, send) + return + await super().__call__(scope, receive, send) + + class ExtensionsRedirectMiddleware: # Extensions are allowed to specify redirect paths. # A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints.