[REFACTOR] payments sse endpoint (#1781)

* exclude sse from gzip

* refactor sse endpoint

* cleanup imports
This commit is contained in:
jackstar12
2023-08-18 12:05:14 +02:00
committed by GitHub
parent d4de78f1e8
commit 8a6e411a0d
4 changed files with 24 additions and 24 deletions

View File

@@ -14,7 +14,6 @@ from typing import Callable, List
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from loguru import logger from loguru import logger
from slowapi import Limiter from slowapi import Limiter
@@ -45,6 +44,7 @@ from .core.views.generic import core_html_routes
from .extension_manager import Extension, InstallableExtension, get_valid_extensions from .extension_manager import Extension, InstallableExtension, get_valid_extensions
from .helpers import template_renderer from .helpers import template_renderer
from .middleware import ( from .middleware import (
CustomGZipMiddleware,
ExtensionsRedirectMiddleware, ExtensionsRedirectMiddleware,
InstalledExtensionMiddleware, InstalledExtensionMiddleware,
add_ip_block_middleware, add_ip_block_middleware,
@@ -87,7 +87,9 @@ def create_app() -> FastAPI:
CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"] CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
) )
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(
CustomGZipMiddleware, minimum_size=1000, exclude_paths=["/api/v1/payments/sse"]
)
# order of these two middlewares is important # order of these two middlewares is important
app.add_middleware(InstalledExtensionMiddleware) app.add_middleware(InstalledExtensionMiddleware)

View File

@@ -170,6 +170,8 @@ class Payment(FromRowModel):
async def set_pending(self, pending: bool) -> None: async def set_pending(self, pending: bool) -> None:
from .crud import update_payment_status from .crud import update_payment_status
self.pending = pending
await update_payment_status(self.checking_id, pending) await update_payment_status(self.checking_id, pending)
async def check_status( async def check_status(

View File

@@ -4,10 +4,9 @@ import json
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Union
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse
import async_timeout
import httpx import httpx
import pyqrcode import pyqrcode
from fastapi import ( from fastapi import (
@@ -421,34 +420,18 @@ async def subscribe_wallet_invoices(request: Request, wallet: Wallet):
logger.debug(f"adding sse listener for wallet: {uid}") logger.debug(f"adding sse listener for wallet: {uid}")
api_invoice_listeners[uid] = payment_queue api_invoice_listeners[uid] = payment_queue
send_queue: asyncio.Queue[Tuple[str, Payment]] = asyncio.Queue(0)
async def payment_received() -> None:
while True:
try:
async with async_timeout.timeout(1):
payment: Payment = await payment_queue.get()
if payment.wallet_id == this_wallet_id:
logger.debug("sse listener: payment received", payment)
await send_queue.put(("payment-received", payment))
except asyncio.TimeoutError:
pass
task = asyncio.create_task(payment_received())
try: try:
while True: while True:
if await request.is_disconnected(): if await request.is_disconnected():
await request.close() await request.close()
break break
typ, data = await send_queue.get() payment: Payment = await payment_queue.get()
if data: if payment.wallet_id == this_wallet_id:
jdata = json.dumps(dict(data.dict(), pending=False)) logger.debug("sse listener: payment received", payment)
yield dict(data=jdata, event=typ) yield dict(data=payment.json(), event="payment-received")
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug(f"removing listener for wallet {uid}") logger.debug(f"removing listener for wallet {uid}")
api_invoice_listeners.pop(uid) api_invoice_listeners.pop(uid)
task.cancel()
return return

View File

@@ -7,6 +7,7 @@ from fastapi.responses import HTMLResponse, JSONResponse
from slowapi import _rate_limit_exceeded_handler from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware from slowapi.middleware import SlowAPIMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.core import core_app_extra from lnbits.core import core_app_extra
@@ -115,6 +116,18 @@ class InstalledExtensionMiddleware:
) )
class CustomGZipMiddleware(GZipMiddleware):
def __init__(self, *args, exclude_paths=None, **kwargs):
super().__init__(*args, **kwargs)
self.exclude_paths = exclude_paths or []
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if "path" in scope and scope["path"] in self.exclude_paths:
await self.app(scope, receive, send)
return
await super().__call__(scope, receive, send)
class ExtensionsRedirectMiddleware: class ExtensionsRedirectMiddleware:
# Extensions are allowed to specify redirect paths. # Extensions are allowed to specify redirect paths.
# A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints. # A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints.