mirror of
https://github.com/lnbits/lnbits.git
synced 2025-03-17 21:31:55 +01:00
[REFACTOR] payments sse endpoint (#1781)
* exclude sse from gzip * refactor sse endpoint * cleanup imports
This commit is contained in:
parent
d4de78f1e8
commit
8a6e411a0d
@ -14,7 +14,6 @@ from typing import Callable, List
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from loguru import logger
|
||||
from slowapi import Limiter
|
||||
@ -45,6 +44,7 @@ from .core.views.generic import core_html_routes
|
||||
from .extension_manager import Extension, InstallableExtension, get_valid_extensions
|
||||
from .helpers import template_renderer
|
||||
from .middleware import (
|
||||
CustomGZipMiddleware,
|
||||
ExtensionsRedirectMiddleware,
|
||||
InstalledExtensionMiddleware,
|
||||
add_ip_block_middleware,
|
||||
@ -87,7 +87,9 @@ def create_app() -> FastAPI:
|
||||
CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
app.add_middleware(
|
||||
CustomGZipMiddleware, minimum_size=1000, exclude_paths=["/api/v1/payments/sse"]
|
||||
)
|
||||
|
||||
# order of these two middlewares is important
|
||||
app.add_middleware(InstalledExtensionMiddleware)
|
||||
|
@ -170,6 +170,8 @@ class Payment(FromRowModel):
|
||||
async def set_pending(self, pending: bool) -> None:
|
||||
from .crud import update_payment_status
|
||||
|
||||
self.pending = pending
|
||||
|
||||
await update_payment_status(self.checking_id, pending)
|
||||
|
||||
async def check_status(
|
||||
|
@ -4,10 +4,9 @@ import json
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import async_timeout
|
||||
import httpx
|
||||
import pyqrcode
|
||||
from fastapi import (
|
||||
@ -421,34 +420,18 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet):
|
||||
logger.debug(f"adding sse listener for wallet: {uid}")
|
||||
api_invoice_listeners[uid] = payment_queue
|
||||
|
||||
send_queue: asyncio.Queue[Tuple[str, Payment]] = asyncio.Queue(0)
|
||||
|
||||
async def payment_received() -> None:
|
||||
while True:
|
||||
try:
|
||||
async with async_timeout.timeout(1):
|
||||
payment: Payment = await payment_queue.get()
|
||||
if payment.wallet_id == this_wallet_id:
|
||||
logger.debug("sse listener: payment received", payment)
|
||||
await send_queue.put(("payment-received", payment))
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
task = asyncio.create_task(payment_received())
|
||||
|
||||
try:
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
await request.close()
|
||||
break
|
||||
typ, data = await send_queue.get()
|
||||
if data:
|
||||
jdata = json.dumps(dict(data.dict(), pending=False))
|
||||
yield dict(data=jdata, event=typ)
|
||||
payment: Payment = await payment_queue.get()
|
||||
if payment.wallet_id == this_wallet_id:
|
||||
logger.debug("sse listener: payment received", payment)
|
||||
yield dict(data=payment.json(), event="payment-received")
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"removing listener for wallet {uid}")
|
||||
api_invoice_listeners.pop(uid)
|
||||
task.cancel()
|
||||
return
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from lnbits.core import core_app_extra
|
||||
@ -115,6 +116,18 @@ class InstalledExtensionMiddleware:
|
||||
)
|
||||
|
||||
|
||||
class CustomGZipMiddleware(GZipMiddleware):
|
||||
def __init__(self, *args, exclude_paths=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.exclude_paths = exclude_paths or []
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if "path" in scope and scope["path"] in self.exclude_paths:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
await super().__call__(scope, receive, send)
|
||||
|
||||
|
||||
class ExtensionsRedirectMiddleware:
|
||||
# Extensions are allowed to specify redirect paths.
|
||||
# A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints.
|
||||
|
Loading…
x
Reference in New Issue
Block a user