Merge pull request #1146 from lnbits/universalwebsocket

Adds universal websocket manager any extension can use
This commit is contained in:
Arc
2022-12-01 14:55:40 +00:00
committed by GitHub
3 changed files with 75 additions and 9 deletions

View File

@@ -166,7 +166,7 @@ def lnencode(addr, privkey):
if addr.amount: if addr.amount:
amount = Decimal(str(addr.amount)) amount = Decimal(str(addr.amount))
# We can only send down to millisatoshi. # We can only send down to millisatoshi.
if amount * 10**12 % 10: if amount * 10 ** 12 % 10:
raise ValueError( raise ValueError(
"Cannot encode {}: too many decimal places".format(addr.amount) "Cannot encode {}: too many decimal places".format(addr.amount)
) )
@@ -271,7 +271,7 @@ class LnAddr(object):
def shorten_amount(amount): def shorten_amount(amount):
"""Given an amount in bitcoin, shorten it""" """Given an amount in bitcoin, shorten it"""
# Convert to pico initially # Convert to pico initially
amount = int(amount * 10**12) amount = int(amount * 10 ** 12)
units = ["p", "n", "u", "m", ""] units = ["p", "n", "u", "m", ""]
for unit in units: for unit in units:
if amount % 1000 == 0: if amount % 1000 == 0:
@@ -290,7 +290,7 @@ def _unshorten_amount(amount: str) -> int:
# * `u` (micro): multiply by 0.000001 # * `u` (micro): multiply by 0.000001
# * `n` (nano): multiply by 0.000000001 # * `n` (nano): multiply by 0.000000001
# * `p` (pico): multiply by 0.000000000001 # * `p` (pico): multiply by 0.000000000001
units = {"p": 10**12, "n": 10**9, "u": 10**6, "m": 10**3} units = {"p": 10 ** 12, "n": 10 ** 9, "u": 10 ** 6, "m": 10 ** 3}
unit = str(amount)[-1] unit = str(amount)[-1]
# BOLT #11: # BOLT #11:

View File

@@ -2,11 +2,11 @@ import asyncio
import json import json
from binascii import unhexlify from binascii import unhexlify
from io import BytesIO from io import BytesIO
from typing import Dict, Optional, Tuple from typing import Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import httpx import httpx
from fastapi import Depends from fastapi import Depends, WebSocket, WebSocketDisconnect
from lnurl import LnurlErrorResponse from lnurl import LnurlErrorResponse
from lnurl import decode as decode_lnurl # type: ignore from lnurl import decode as decode_lnurl # type: ignore
from loguru import logger from loguru import logger
@@ -329,12 +329,12 @@ async def perform_lnurlauth(
sign_len = 6 + r_len + s_len sign_len = 6 + r_len + s_len
signature = BytesIO() signature = BytesIO()
signature.write(0x30.to_bytes(1, "big", signed=False)) signature.write(0x30 .to_bytes(1, "big", signed=False))
signature.write((sign_len - 2).to_bytes(1, "big", signed=False)) signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
signature.write(0x02.to_bytes(1, "big", signed=False)) signature.write(0x02 .to_bytes(1, "big", signed=False))
signature.write(r_len.to_bytes(1, "big", signed=False)) signature.write(r_len.to_bytes(1, "big", signed=False))
signature.write(r) signature.write(r)
signature.write(0x02.to_bytes(1, "big", signed=False)) signature.write(0x02 .to_bytes(1, "big", signed=False))
signature.write(s_len.to_bytes(1, "big", signed=False)) signature.write(s_len.to_bytes(1, "big", signed=False))
signature.write(s) signature.write(s)
@@ -382,3 +382,28 @@ async def check_transaction_status(
# WARN: this same value must be used for balance check and passed to WALLET.pay_invoice(), it may cause a vulnerability if the values differ # WARN: this same value must be used for balance check and passed to WALLET.pay_invoice(), it may cause a vulnerability if the values differ
def fee_reserve(amount_msat: int) -> int: def fee_reserve(amount_msat: int) -> int:
return max(int(RESERVE_FEE_MIN), int(amount_msat * RESERVE_FEE_PERCENT / 100.0)) return max(int(RESERVE_FEE_MIN), int(amount_msat * RESERVE_FEE_PERCENT / 100.0))
class WebsocketConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
logger.debug(websocket)
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_data(self, message: str, item_id: str):
for connection in self.active_connections:
if connection.path_params["item_id"] == item_id:
await connection.send_text(message)
websocketManager = WebsocketConnectionManager()
async def websocketUpdater(item_id, data):
return await websocketManager.send_data(f"{data}", item_id)

View File

@@ -12,7 +12,15 @@ from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse
import async_timeout import async_timeout
import httpx import httpx
import pyqrcode import pyqrcode
from fastapi import Depends, Header, Query, Request, Response from fastapi import (
Depends,
Header,
Query,
Request,
Response,
WebSocket,
WebSocketDisconnect,
)
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.params import Body from fastapi.params import Body
from loguru import logger from loguru import logger
@@ -56,6 +64,8 @@ from ..services import (
create_invoice, create_invoice,
pay_invoice, pay_invoice,
perform_lnurlauth, perform_lnurlauth,
websocketManager,
websocketUpdater,
) )
from ..tasks import api_invoice_listeners from ..tasks import api_invoice_listeners
@@ -697,3 +707,34 @@ async def api_auditor(wallet: WalletTypeInfo = Depends(get_key_type)):
"delta_msats": delta, "delta_msats": delta,
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
##################UNIVERSAL WEBSOCKET MANAGER########################
@core_app.websocket("/api/v1/ws/{item_id}")
async def websocket_connect(websocket: WebSocket, item_id: str):
await websocketManager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
except WebSocketDisconnect:
websocketManager.disconnect(websocket)
@core_app.post("/api/v1/ws/{item_id}")
async def websocket_update_post(item_id: str, data: str):
try:
await websocketUpdater(item_id, data)
return {"sent": True, "data": data}
except:
return {"sent": False, "data": data}
@core_app.get("/api/v1/ws/{item_id}/{data}")
async def websocket_update_get(item_id: str, data: str):
try:
await websocketUpdater(item_id, data)
return {"sent": True, "data": data}
except:
return {"sent": False, "data": data}