diff --git a/lnbits/bolt11.py b/lnbits/bolt11.py index 32b43feb6..08f1f1e59 100644 --- a/lnbits/bolt11.py +++ b/lnbits/bolt11.py @@ -166,7 +166,7 @@ def lnencode(addr, privkey): if addr.amount: amount = Decimal(str(addr.amount)) # We can only send down to millisatoshi. - if amount * 10**12 % 10: + if amount * 10 ** 12 % 10: raise ValueError( "Cannot encode {}: too many decimal places".format(addr.amount) ) @@ -271,7 +271,7 @@ class LnAddr(object): def shorten_amount(amount): """Given an amount in bitcoin, shorten it""" # Convert to pico initially - amount = int(amount * 10**12) + amount = int(amount * 10 ** 12) units = ["p", "n", "u", "m", ""] for unit in units: if amount % 1000 == 0: @@ -290,7 +290,7 @@ def _unshorten_amount(amount: str) -> int: # * `u` (micro): multiply by 0.000001 # * `n` (nano): multiply by 0.000000001 # * `p` (pico): multiply by 0.000000000001 - units = {"p": 10**12, "n": 10**9, "u": 10**6, "m": 10**3} + units = {"p": 10 ** 12, "n": 10 ** 9, "u": 10 ** 6, "m": 10 ** 3} unit = str(amount)[-1] # BOLT #11: diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 5d993b4c5..623f78139 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -2,11 +2,11 @@ import asyncio import json from binascii import unhexlify from io import BytesIO -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlparse import httpx -from fastapi import Depends +from fastapi import Depends, WebSocket, WebSocketDisconnect from lnurl import LnurlErrorResponse from lnurl import decode as decode_lnurl # type: ignore from loguru import logger @@ -329,12 +329,12 @@ async def perform_lnurlauth( sign_len = 6 + r_len + s_len signature = BytesIO() - signature.write(0x30.to_bytes(1, "big", signed=False)) + signature.write(0x30 .to_bytes(1, "big", signed=False)) signature.write((sign_len - 2).to_bytes(1, "big", signed=False)) - signature.write(0x02.to_bytes(1, "big", signed=False)) + signature.write(0x02 .to_bytes(1, "big", signed=False)) signature.write(r_len.to_bytes(1, "big", signed=False)) signature.write(r) - signature.write(0x02.to_bytes(1, "big", signed=False)) + signature.write(0x02 .to_bytes(1, "big", signed=False)) signature.write(s_len.to_bytes(1, "big", signed=False)) signature.write(s) @@ -382,3 +382,28 @@ async def check_transaction_status( # WARN: this same value must be used for balance check and passed to WALLET.pay_invoice(), it may cause a vulnerability if the values differ def fee_reserve(amount_msat: int) -> int: return max(int(RESERVE_FEE_MIN), int(amount_msat * RESERVE_FEE_PERCENT / 100.0)) + + +class WebsocketConnectionManager: + def __init__(self): + self.active_connections: List[WebSocket] = [] + + async def connect(self, websocket: WebSocket): + await websocket.accept() + logger.debug(websocket) + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket): + self.active_connections.remove(websocket) + + async def send_data(self, message: str, item_id: str): + for connection in self.active_connections: + if connection.path_params["item_id"] == item_id: + await connection.send_text(message) + + +websocketManager = WebsocketConnectionManager() + + +async def websocketUpdater(item_id, data): + return await websocketManager.send_data(f"{data}", item_id) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index ae3e6a5ef..f78219bf1 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -12,7 +12,15 @@ from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse import async_timeout import httpx import pyqrcode -from fastapi import Depends, Header, Query, Request, Response +from fastapi import ( + Depends, + Header, + Query, + Request, + Response, + WebSocket, + WebSocketDisconnect, +) from fastapi.exceptions import HTTPException from fastapi.params import Body from loguru import logger @@ -56,6 +64,8 @@ from ..services import ( create_invoice, pay_invoice, perform_lnurlauth, + websocketManager, + websocketUpdater, ) from ..tasks import api_invoice_listeners @@ -697,3 +707,34 @@ async def api_auditor(wallet: WalletTypeInfo = Depends(get_key_type)): "delta_msats": delta, "timestamp": int(time.time()), } + + +##################UNIVERSAL WEBSOCKET MANAGER######################## + + +@core_app.websocket("/api/v1/ws/{item_id}") +async def websocket_connect(websocket: WebSocket, item_id: str): + await websocketManager.connect(websocket) + try: + while True: + data = await websocket.receive_text() + except WebSocketDisconnect: + websocketManager.disconnect(websocket) + + +@core_app.post("/api/v1/ws/{item_id}") +async def websocket_update_post(item_id: str, data: str): + try: + await websocketUpdater(item_id, data) + return {"sent": True, "data": data} + except: + return {"sent": False, "data": data} + + +@core_app.get("/api/v1/ws/{item_id}/{data}") +async def websocket_update_get(item_id: str, data: str): + try: + await websocketUpdater(item_id, data) + return {"sent": True, "data": data} + except: + return {"sent": False, "data": data}