feat: add has_connection, listen and receive_queue to websocket_manager (#3330)

This commit is contained in:
dni ⚡
2025-09-02 10:02:22 +02:00
committed by GitHub
parent 0d91d5b5be
commit 5021570f68
3 changed files with 62 additions and 38 deletions

View File

@@ -277,7 +277,7 @@ async def send_payment_notification(wallet: Wallet, payment: Payment):
async def send_ws_payment_notification(wallet: Wallet, payment: Payment):
# TODO: websocket message should be a clean payment model
# await websocket_manager.send_data(payment.json(), wallet.inkey)
# await websocket_manager.send(wallet.inkey, payment.json())
# TODO: figure out why we send the balance with the payment here.
# cleaner would be to have a separate message for the balance
# and send it with the id of the wallet so wallets can subscribe to it
@@ -288,12 +288,11 @@ async def send_ws_payment_notification(wallet: Wallet, payment: Payment):
"payment": json.loads(payment.json()),
},
)
await websocket_manager.send_data(payment_notification, wallet.inkey)
await websocket_manager.send_data(payment_notification, wallet.adminkey)
await websocket_manager.send_data(
json.dumps({"pending": payment.pending, "status": payment.status}),
await websocket_manager.send(wallet.inkey, payment_notification)
await websocket_manager.send(wallet.adminkey, payment_notification)
await websocket_manager.send(
payment.payment_hash,
json.dumps({"pending": payment.pending, "status": payment.status}),
)

View File

@@ -1,27 +1,65 @@
from fastapi import WebSocket
from asyncio import Queue
from dataclasses import dataclass
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger
from lnbits.settings import settings
@dataclass
class WebsocketConnection:
item_id: str
websocket: WebSocket
receive_queue: Queue[str]
class WebsocketConnectionManager:
def __init__(self) -> None:
self.active_connections: list[WebSocket] = []
self.active_connections: list[WebsocketConnection] = []
async def connect(self, websocket: WebSocket, item_id: str):
async def connect(self, item_id: str, websocket: WebSocket) -> WebsocketConnection:
logger.debug(f"Websocket connected to {item_id}")
await websocket.accept()
self.active_connections.append(websocket)
conn = WebsocketConnection(
item_id=item_id,
websocket=websocket,
receive_queue=Queue(),
)
self.active_connections.append(conn)
return conn
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def listen(self, conn: WebsocketConnection) -> None:
while settings.lnbits_running:
try:
data = await conn.websocket.receive_text()
logger.debug(f"WS received data from {conn.item_id}: {data}")
conn.receive_queue.put_nowait(data)
except WebSocketDisconnect:
for _conn in self.active_connections:
if _conn.websocket == conn.websocket:
self.active_connections.remove(_conn)
logger.debug(f"WS disconnected from {conn.item_id}")
break # out of the listen and the fastapi route
async def send_data(self, message: str, item_id: str):
for connection in self.active_connections:
if connection.path_params["item_id"] == item_id:
await connection.send_text(message)
def get_connections(self, item_id: str) -> list[WebsocketConnection]:
conns = []
for conn in self.active_connections:
if conn.item_id == item_id:
conns.append(conn)
return conns
def has_connection(self, item_id: str) -> bool:
return len(self.get_connections(item_id)) > 0
async def send(self, item_id: str, data: str) -> None:
for conn in self.get_connections(item_id):
await conn.websocket.send_text(data)
websocket_manager = WebsocketConnectionManager()
async def websocket_updater(item_id: str, data: str):
return await websocket_manager.send_data(data, item_id)
# deprecated import and use `websocket_manager.send()` instead
async def websocket_updater(item_id: str, data: str) -> None:
return await websocket_manager.send(item_id, data)

View File

@@ -1,33 +1,20 @@
from fastapi import (
APIRouter,
WebSocket,
WebSocketDisconnect,
)
from fastapi import APIRouter, WebSocket
from lnbits.settings import settings
from ..services import (
websocket_manager,
websocket_updater,
)
from ..services import websocket_manager
websocket_router = APIRouter(prefix="/api/v1/ws", tags=["Websocket"])
@websocket_router.websocket("/{item_id}")
async def websocket_connect(websocket: WebSocket, item_id: str):
await websocket_manager.connect(websocket, item_id)
try:
while settings.lnbits_running:
await websocket.receive_text()
except WebSocketDisconnect:
websocket_manager.disconnect(websocket)
async def websocket_connect(websocket: WebSocket, item_id: str) -> None:
conn = await websocket_manager.connect(item_id, websocket)
await websocket_manager.listen(conn)
@websocket_router.post("/{item_id}")
async def websocket_update_post(item_id: str, data: str):
try:
await websocket_updater(item_id, data)
await websocket_manager.send(item_id, data)
return {"sent": True, "data": data}
except Exception:
return {"sent": False, "data": data}
@@ -36,7 +23,7 @@ async def websocket_update_post(item_id: str, data: str):
@websocket_router.get("/{item_id}/{data}")
async def websocket_update_get(item_id: str, data: str):
try:
await websocket_updater(item_id, data)
await websocket_manager.send(item_id, data)
return {"sent": True, "data": data}
except Exception:
return {"sent": False, "data": data}