From e6499104c047fb848136390cf62ec6abbab92f13 Mon Sep 17 00:00:00 2001 From: jackstar12 <62219658+jackstar12@users.noreply.github.com> Date: Mon, 19 Jun 2023 12:12:00 +0200 Subject: [PATCH] Wallets refactor (#1729) * feat: cleanup function for wallet * update eclair implementation * update lnd implementation * update lnbits implementation * update lnpay implementation * update lnbits implementation * update opennode implementation * update spark implementation * use base_url for clients * fix lnpay * fix opennode * fix lntips * test real invoice creation * add small delay to test * test paid invoice stream * fix lnbits * fix lndrest * fix spark fix spark * check node balance in test * increase balance check delay * check balance in pay test aswell * make sure get_payment_status is called * fix lndrest * revert unnecessary changes --- lnbits/app.py | 11 +++- lnbits/wallets/base.py | 3 + lnbits/wallets/eclair.py | 61 ++++++++----------- lnbits/wallets/lnbits.py | 55 +++++++---------- lnbits/wallets/lndrest.py | 113 ++++++++++++++++------------------- lnbits/wallets/lnpay.py | 52 +++++++--------- lnbits/wallets/lntips.py | 78 +++++++++++------------- lnbits/wallets/opennode.py | 51 +++++++--------- lnbits/wallets/spark.py | 37 ++++++------ tests/core/views/test_api.py | 55 +++++++++++++++-- tests/helpers.py | 5 +- 11 files changed, 259 insertions(+), 262 deletions(-) diff --git a/lnbits/app.py b/lnbits/app.py index 050b78608..d3dfa4407 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -46,7 +46,6 @@ from .tasks import ( def create_app() -> FastAPI: - configure_logger() app = FastAPI( @@ -82,6 +81,7 @@ def create_app() -> FastAPI: register_routes(app) register_async_tasks(app) register_exception_handlers(app) + register_shutdown(app) # Allow registering new extensions routes without direct access to the `app` object setattr(core_app_extra, "register_new_ext_routes", register_new_ext_routes(app)) @@ -90,7 +90,6 @@ def create_app() -> FastAPI: async def check_funding_source() -> None: - original_sigint_handler = signal.getsignal(signal.SIGINT) def signal_handler(signal, frame): @@ -279,7 +278,6 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None: def register_startup(app: FastAPI): @app.on_event("startup") async def lnbits_startup(): - try: # wait till migration is done await migrate_databases() @@ -303,6 +301,13 @@ def register_startup(app: FastAPI): raise ImportError("Failed to run 'startup' event.") +def register_shutdown(app: FastAPI): + @app.on_event("shutdown") + async def on_shutdown(): + WALLET = get_wallet_class() + await WALLET.cleanup() + + def log_server_info(): logger.info("Starting LNbits") logger.info(f"Version: {settings.version}") diff --git a/lnbits/wallets/base.py b/lnbits/wallets/base.py index 68e49c9d0..035424531 100644 --- a/lnbits/wallets/base.py +++ b/lnbits/wallets/base.py @@ -48,6 +48,9 @@ class PaymentStatus(NamedTuple): class Wallet(ABC): + async def cleanup(self): + pass + @abstractmethod def status(self) -> Coroutine[None, None, StatusResponse]: pass diff --git a/lnbits/wallets/eclair.py b/lnbits/wallets/eclair.py index 09dea1f67..da40192f7 100644 --- a/lnbits/wallets/eclair.py +++ b/lnbits/wallets/eclair.py @@ -41,12 +41,13 @@ class EclairWallet(Wallet): encodedAuth = base64.b64encode(f":{passw}".encode()) auth = str(encodedAuth, "utf-8") self.auth = {"Authorization": f"Basic {auth}"} + self.client = httpx.AsyncClient(base_url=self.url, headers=self.auth) + + async def cleanup(self): + await self.client.aclose() async def status(self) -> StatusResponse: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.url}/globalbalance", headers=self.auth, timeout=5 - ) + r = await self.client.post("/globalbalance", timeout=5) try: data = r.json() except: @@ -69,7 +70,6 @@ class EclairWallet(Wallet): unhashed_description: Optional[bytes] = None, **kwargs, ) -> InvoiceResponse: - data: Dict[str, Any] = { "amountMsat": amount * 1000, } @@ -84,10 +84,7 @@ class EclairWallet(Wallet): else: data["description"] = memo - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.url}/createinvoice", headers=self.auth, data=data, timeout=40 - ) + r = await self.client.post("/createinvoice", data=data, timeout=40) if r.is_error: try: @@ -102,13 +99,11 @@ class EclairWallet(Wallet): return InvoiceResponse(True, data["paymentHash"], data["serialized"], None) async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.url}/payinvoice", - headers=self.auth, - data={"invoice": bolt11, "blocking": True}, - timeout=None, - ) + r = await self.client.post( + "/payinvoice", + data={"invoice": bolt11, "blocking": True}, + timeout=None, + ) if "error" in r.json(): try: @@ -128,13 +123,11 @@ class EclairWallet(Wallet): # We do all this again to get the fee: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.url}/getsentinfo", - headers=self.auth, - data={"paymentHash": checking_id}, - timeout=40, - ) + r = await self.client.post( + "/getsentinfo", + data={"paymentHash": checking_id}, + timeout=40, + ) if "error" in r.json(): try: @@ -162,12 +155,10 @@ class EclairWallet(Wallet): async def get_invoice_status(self, checking_id: str) -> PaymentStatus: try: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.url}/getreceivedinfo", - headers=self.auth, - data={"paymentHash": checking_id}, - ) + r = await self.client.post( + "/getreceivedinfo", + data={"paymentHash": checking_id}, + ) r.raise_for_status() data = r.json() @@ -186,13 +177,11 @@ class EclairWallet(Wallet): async def get_payment_status(self, checking_id: str) -> PaymentStatus: try: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.url}/getsentinfo", - headers=self.auth, - data={"paymentHash": checking_id}, - timeout=40, - ) + r = await self.client.post( + "/getsentinfo", + data={"paymentHash": checking_id}, + timeout=40, + ) r.raise_for_status() diff --git a/lnbits/wallets/lnbits.py b/lnbits/wallets/lnbits.py index 902711d67..f91ef6deb 100644 --- a/lnbits/wallets/lnbits.py +++ b/lnbits/wallets/lnbits.py @@ -29,17 +29,18 @@ class LNbitsWallet(Wallet): if not self.endpoint or not key: raise Exception("cannot initialize lnbits wallet") self.key = {"X-Api-Key": key} + self.client = httpx.AsyncClient(base_url=self.endpoint, headers=self.key) + + async def cleanup(self): + await self.client.aclose() async def status(self) -> StatusResponse: - async with httpx.AsyncClient() as client: - try: - r = await client.get( - url=f"{self.endpoint}/api/v1/wallet", headers=self.key, timeout=15 - ) - except Exception as exc: - return StatusResponse( - f"Failed to connect to {self.endpoint} due to: {exc}", 0 - ) + try: + r = await self.client.get(url="/api/v1/wallet", timeout=15) + except Exception as exc: + return StatusResponse( + f"Failed to connect to {self.endpoint} due to: {exc}", 0 + ) try: data = r.json() @@ -69,10 +70,7 @@ class LNbitsWallet(Wallet): if unhashed_description: data["unhashed_description"] = unhashed_description.hex() - async with httpx.AsyncClient() as client: - r = await client.post( - url=f"{self.endpoint}/api/v1/payments", headers=self.key, json=data - ) + r = await self.client.post(url="/api/v1/payments", json=data) ok, checking_id, payment_request, error_message = ( not r.is_error, None, @@ -89,20 +87,12 @@ class LNbitsWallet(Wallet): return InvoiceResponse(ok, checking_id, payment_request, error_message) async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: - async with httpx.AsyncClient() as client: - r = await client.post( - url=f"{self.endpoint}/api/v1/payments", - headers=self.key, - json={"out": True, "bolt11": bolt11}, - timeout=None, - ) - ok, checking_id, _, _, error_message = ( - not r.is_error, - None, - None, - None, - None, + r = await self.client.post( + url="/api/v1/payments", + json={"out": True, "bolt11": bolt11}, + timeout=None, ) + ok = not r.is_error if r.is_error: error_message = r.json()["detail"] @@ -118,11 +108,9 @@ class LNbitsWallet(Wallet): async def get_invoice_status(self, checking_id: str) -> PaymentStatus: try: - async with httpx.AsyncClient() as client: - r = await client.get( - url=f"{self.endpoint}/api/v1/payments/{checking_id}", - headers=self.key, - ) + r = await self.client.get( + url=f"/api/v1/payments/{checking_id}", + ) if r.is_error: return PaymentStatus(None) return PaymentStatus(r.json()["paid"]) @@ -130,10 +118,7 @@ class LNbitsWallet(Wallet): return PaymentStatus(None) async def get_payment_status(self, checking_id: str) -> PaymentStatus: - async with httpx.AsyncClient() as client: - r = await client.get( - url=f"{self.endpoint}/api/v1/payments/{checking_id}", headers=self.key - ) + r = await self.client.get(url=f"/api/v1/payments/{checking_id}") if r.is_error: return PaymentStatus(None) diff --git a/lnbits/wallets/lndrest.py b/lnbits/wallets/lndrest.py index a8a94dab0..99ceaf540 100644 --- a/lnbits/wallets/lndrest.py +++ b/lnbits/wallets/lndrest.py @@ -64,14 +64,17 @@ class LndRestWallet(Wallet): self.cert = cert or True self.auth = {"Grpc-Metadata-macaroon": self.macaroon} + self.client = httpx.AsyncClient( + base_url=self.endpoint, headers=self.auth, verify=self.cert + ) + + async def cleanup(self): + await self.client.aclose() async def status(self) -> StatusResponse: try: - async with httpx.AsyncClient(verify=self.cert) as client: - r = await client.get( - f"{self.endpoint}/v1/balance/channels", headers=self.auth - ) - r.raise_for_status() + r = await self.client.get("/v1/balance/channels") + r.raise_for_status() except (httpx.ConnectError, httpx.RequestError) as exc: return StatusResponse(f"Unable to connect to {self.endpoint}. {exc}", 0) @@ -104,10 +107,7 @@ class LndRestWallet(Wallet): hashlib.sha256(unhashed_description).digest() ).decode("ascii") - async with httpx.AsyncClient(verify=self.cert) as client: - r = await client.post( - url=f"{self.endpoint}/v1/invoices", headers=self.auth, json=data - ) + r = await self.client.post(url="/v1/invoices", json=data) if r.is_error: error_message = r.text @@ -125,17 +125,15 @@ class LndRestWallet(Wallet): return InvoiceResponse(True, checking_id, payment_request, None) async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: - async with httpx.AsyncClient(verify=self.cert) as client: - # set the fee limit for the payment - lnrpcFeeLimit = dict() - lnrpcFeeLimit["fixed_msat"] = f"{fee_limit_msat}" + # set the fee limit for the payment + lnrpcFeeLimit = dict() + lnrpcFeeLimit["fixed_msat"] = f"{fee_limit_msat}" - r = await client.post( - url=f"{self.endpoint}/v1/channels/transactions", - headers=self.auth, - json={"payment_request": bolt11, "fee_limit": lnrpcFeeLimit}, - timeout=None, - ) + r = await self.client.post( + url="/v1/channels/transactions", + json={"payment_request": bolt11, "fee_limit": lnrpcFeeLimit}, + timeout=None, + ) if r.is_error or r.json().get("payment_error"): error_message = r.json().get("payment_error") or r.text @@ -148,10 +146,7 @@ class LndRestWallet(Wallet): return PaymentResponse(True, checking_id, fee_msat, preimage, None) async def get_invoice_status(self, checking_id: str) -> PaymentStatus: - async with httpx.AsyncClient(verify=self.cert) as client: - r = await client.get( - url=f"{self.endpoint}/v1/invoice/{checking_id}", headers=self.auth - ) + r = await self.client.get(url=f"/v1/invoice/{checking_id}") if r.is_error or not r.json().get("settled"): # this must also work when checking_id is not a hex recognizable by lnd @@ -172,7 +167,7 @@ class LndRestWallet(Wallet): except ValueError: return PaymentStatus(None) - url = f"{self.endpoint}/v2/router/track/{checking_id}" + url = f"/v2/router/track/{checking_id}" # check payment.status: # https://api.lightning.community/?python=#paymentpaymentstatus @@ -183,52 +178,46 @@ class LndRestWallet(Wallet): "FAILED": False, } - async with httpx.AsyncClient( - timeout=None, headers=self.auth, verify=self.cert - ) as client: - async with client.stream("GET", url) as r: - async for json_line in r.aiter_lines(): - try: - line = json.loads(json_line) - if line.get("error"): - logger.error( - line["error"]["message"] - if "message" in line["error"] - else line["error"] - ) - return PaymentStatus(None) - payment = line.get("result") - if payment is not None and payment.get("status"): - return PaymentStatus( - paid=statuses[payment["status"]], - fee_msat=payment.get("fee_msat"), - preimage=payment.get("payment_preimage"), - ) - else: - return PaymentStatus(None) - except: - continue + async with self.client.stream("GET", url, timeout=None) as r: + async for json_line in r.aiter_lines(): + try: + line = json.loads(json_line) + if line.get("error"): + logger.error( + line["error"]["message"] + if "message" in line["error"] + else line["error"] + ) + return PaymentStatus(None) + payment = line.get("result") + if payment is not None and payment.get("status"): + return PaymentStatus( + paid=statuses[payment["status"]], + fee_msat=payment.get("fee_msat"), + preimage=payment.get("payment_preimage"), + ) + else: + return PaymentStatus(None) + except: + continue return PaymentStatus(None) async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: while True: try: - url = self.endpoint + "/v1/invoices/subscribe" - async with httpx.AsyncClient( - timeout=None, headers=self.auth, verify=self.cert - ) as client: - async with client.stream("GET", url) as r: - async for line in r.aiter_lines(): - try: - inv = json.loads(line)["result"] - if not inv["settled"]: - continue - except: + url = "/v1/invoices/subscribe" + async with self.client.stream("GET", url, timeout=None) as r: + async for line in r.aiter_lines(): + try: + inv = json.loads(line)["result"] + if not inv["settled"]: continue + except: + continue - payment_hash = base64.b64decode(inv["r_hash"]).hex() - yield payment_hash + payment_hash = base64.b64decode(inv["r_hash"]).hex() + yield payment_hash except Exception as exc: logger.error( f"lost connection to lnd invoices stream: '{exc}', retrying in 5 seconds" diff --git a/lnbits/wallets/lnpay.py b/lnbits/wallets/lnpay.py index f05e44320..af60e10be 100644 --- a/lnbits/wallets/lnpay.py +++ b/lnbits/wallets/lnpay.py @@ -32,12 +32,15 @@ class LNPayWallet(Wallet): self.wallet_key = wallet_key self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.auth = {"X-Api-Key": settings.lnpay_api_key} + self.client = httpx.AsyncClient(base_url=self.endpoint, headers=self.auth) + + async def cleanup(self): + await self.client.aclose() async def status(self) -> StatusResponse: - url = f"{self.endpoint}/wallet/{self.wallet_key}" + url = f"/wallet/{self.wallet_key}" try: - async with httpx.AsyncClient() as client: - r = await client.get(url, headers=self.auth, timeout=60) + r = await self.client.get(url, timeout=60) except (httpx.ConnectError, httpx.RequestError): return StatusResponse(f"Unable to connect to '{url}'", 0) @@ -69,13 +72,11 @@ class LNPayWallet(Wallet): else: data["memo"] = memo or "" - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/wallet/{self.wallet_key}/invoice", - headers=self.auth, - json=data, - timeout=60, - ) + r = await self.client.post( + f"/wallet/{self.wallet_key}/invoice", + json=data, + timeout=60, + ) ok, checking_id, payment_request, error_message = ( r.status_code == 201, None, @@ -90,13 +91,11 @@ class LNPayWallet(Wallet): return InvoiceResponse(ok, checking_id, payment_request, error_message) async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/wallet/{self.wallet_key}/withdraw", - headers=self.auth, - json={"payment_request": bolt11}, - timeout=None, - ) + r = await self.client.post( + f"/wallet/{self.wallet_key}/withdraw", + json={"payment_request": bolt11}, + timeout=None, + ) try: data = r.json() @@ -117,11 +116,9 @@ class LNPayWallet(Wallet): return await self.get_payment_status(checking_id) async def get_payment_status(self, checking_id: str) -> PaymentStatus: - async with httpx.AsyncClient() as client: - r = await client.get( - url=f"{self.endpoint}/lntx/{checking_id}", - headers=self.auth, - ) + r = await self.client.get( + url=f"/lntx/{checking_id}", + ) if r.is_error: return PaymentStatus(None) @@ -155,12 +152,9 @@ class LNPayWallet(Wallet): raise HTTPException(status_code=HTTPStatus.NO_CONTENT) lntx_id = data["data"]["wtx"]["lnTx"]["id"] - async with httpx.AsyncClient() as client: - r = await client.get( - f"{self.endpoint}/lntx/{lntx_id}?fields=settled", headers=self.auth - ) - data = r.json() - if data["settled"]: - await self.queue.put(lntx_id) + r = await self.client.get(f"/lntx/{lntx_id}?fields=settled") + data = r.json() + if data["settled"]: + await self.queue.put(lntx_id) raise HTTPException(status_code=HTTPStatus.NO_CONTENT) diff --git a/lnbits/wallets/lntips.py b/lnbits/wallets/lntips.py index be8159b9c..c8221b6bf 100644 --- a/lnbits/wallets/lntips.py +++ b/lnbits/wallets/lntips.py @@ -30,12 +30,13 @@ class LnTipsWallet(Wallet): raise Exception("cannot initialize lntxbod") self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.auth = {"Authorization": f"Basic {key}"} + self.client = httpx.AsyncClient(base_url=self.endpoint, headers=self.auth) + + async def cleanup(self): + await self.client.aclose() async def status(self) -> StatusResponse: - async with httpx.AsyncClient() as client: - r = await client.get( - f"{self.endpoint}/api/v1/balance", headers=self.auth, timeout=40 - ) + r = await self.client.get("/api/v1/balance", timeout=40) try: data = r.json() except: @@ -62,13 +63,11 @@ class LnTipsWallet(Wallet): elif unhashed_description: data["description_hash"] = hashlib.sha256(unhashed_description).hexdigest() - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/api/v1/createinvoice", - headers=self.auth, - json=data, - timeout=40, - ) + r = await self.client.post( + "/api/v1/createinvoice", + json=data, + timeout=40, + ) if r.is_error: try: @@ -85,13 +84,11 @@ class LnTipsWallet(Wallet): ) async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/api/v1/payinvoice", - headers=self.auth, - json={"pay_req": bolt11}, - timeout=None, - ) + r = await self.client.post( + "/api/v1/payinvoice", + json={"pay_req": bolt11}, + timeout=None, + ) if r.is_error: return PaymentResponse(False, None, 0, None, r.text) @@ -111,11 +108,9 @@ class LnTipsWallet(Wallet): async def get_invoice_status(self, checking_id: str) -> PaymentStatus: try: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/api/v1/invoicestatus/{checking_id}", - headers=self.auth, - ) + r = await self.client.post( + f"/api/v1/invoicestatus/{checking_id}", + ) if r.is_error or len(r.text) == 0: raise Exception @@ -127,11 +122,9 @@ class LnTipsWallet(Wallet): async def get_payment_status(self, checking_id: str) -> PaymentStatus: try: - async with httpx.AsyncClient() as client: - r = await client.post( - url=f"{self.endpoint}/api/v1/paymentstatus/{checking_id}", - headers=self.auth, - ) + r = await self.client.post( + url=f"/api/v1/paymentstatus/{checking_id}", + ) if r.is_error: raise Exception @@ -145,23 +138,22 @@ class LnTipsWallet(Wallet): async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: last_connected = None while True: - url = f"{self.endpoint}/api/v1/invoicestream" + url = "/api/v1/invoicestream" try: - async with httpx.AsyncClient(timeout=None, headers=self.auth) as client: - last_connected = time.time() - async with client.stream("GET", url) as r: - async for line in r.aiter_lines(): - try: - prefix = "data: " - if not line.startswith(prefix): - continue - data = line[len(prefix) :] # sse parsing - inv = json.loads(data) - if not inv.get("payment_hash"): - continue - except: + last_connected = time.time() + async with self.client.stream("GET", url) as r: + async for line in r.aiter_lines(): + try: + prefix = "data: " + if not line.startswith(prefix): continue - yield inv["payment_hash"] + data = line[len(prefix) :] # sse parsing + inv = json.loads(data) + if not inv.get("payment_hash"): + continue + except: + continue + yield inv["payment_hash"] except Exception: pass diff --git a/lnbits/wallets/opennode.py b/lnbits/wallets/opennode.py index 08b234ee4..c75b57abc 100644 --- a/lnbits/wallets/opennode.py +++ b/lnbits/wallets/opennode.py @@ -34,13 +34,14 @@ class OpenNodeWallet(Wallet): self.endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint self.auth = {"Authorization": key} + self.client = httpx.AsyncClient(base_url=self.endpoint, headers=self.auth) + + async def cleanup(self): + await self.client.aclose() async def status(self) -> StatusResponse: try: - async with httpx.AsyncClient() as client: - r = await client.get( - f"{self.endpoint}/v1/account/balance", headers=self.auth, timeout=40 - ) + r = await self.client.get("/v1/account/balance", timeout=40) except (httpx.ConnectError, httpx.RequestError): return StatusResponse(f"Unable to connect to '{self.endpoint}'", 0) @@ -61,17 +62,15 @@ class OpenNodeWallet(Wallet): if description_hash or unhashed_description: raise Unsupported("description_hash") - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/v1/charges", - headers=self.auth, - json={ - "amount": amount, - "description": memo or "", - # "callback_url": url_for("/webhook_listener", _external=True), - }, - timeout=40, - ) + r = await self.client.post( + "/v1/charges", + json={ + "amount": amount, + "description": memo or "", + # "callback_url": url_for("/webhook_listener", _external=True), + }, + timeout=40, + ) if r.is_error: error_message = r.json()["message"] @@ -83,13 +82,11 @@ class OpenNodeWallet(Wallet): return InvoiceResponse(True, checking_id, payment_request, None) async def pay_invoice(self, bolt11: str, fee_limit_msat: int) -> PaymentResponse: - async with httpx.AsyncClient() as client: - r = await client.post( - f"{self.endpoint}/v2/withdrawals", - headers=self.auth, - json={"type": "ln", "address": bolt11}, - timeout=None, - ) + r = await self.client.post( + "/v2/withdrawals", + json={"type": "ln", "address": bolt11}, + timeout=None, + ) if r.is_error: error_message = r.json()["message"] @@ -105,10 +102,7 @@ class OpenNodeWallet(Wallet): return PaymentResponse(True, checking_id, fee_msat, None, None) async def get_invoice_status(self, checking_id: str) -> PaymentStatus: - async with httpx.AsyncClient() as client: - r = await client.get( - f"{self.endpoint}/v1/charge/{checking_id}", headers=self.auth - ) + r = await self.client.get(f"/v1/charge/{checking_id}") if r.is_error: return PaymentStatus(None) data = r.json()["data"] @@ -116,10 +110,7 @@ class OpenNodeWallet(Wallet): return PaymentStatus(statuses[data.get("status")]) async def get_payment_status(self, checking_id: str) -> PaymentStatus: - async with httpx.AsyncClient() as client: - r = await client.get( - f"{self.endpoint}/v1/withdrawal/{checking_id}", headers=self.auth - ) + r = await self.client.get(f"/v1/withdrawal/{checking_id}") if r.is_error: return PaymentStatus(None) diff --git a/lnbits/wallets/spark.py b/lnbits/wallets/spark.py index 8f41e3726..4e02693c7 100644 --- a/lnbits/wallets/spark.py +++ b/lnbits/wallets/spark.py @@ -31,6 +31,13 @@ class SparkWallet(Wallet): assert settings.spark_url, "spark url does not exist" self.url = settings.spark_url.replace("/rpc", "") self.token = settings.spark_token + assert self.token, "spark wallet token does not exist" + self.client = httpx.AsyncClient( + base_url=self.url, headers={"X-Access": self.token} + ) + + async def cleanup(self): + await self.client.aclose() def __getattr__(self, key): async def call(*args, **kwargs): @@ -46,15 +53,12 @@ class SparkWallet(Wallet): params = {} try: - async with httpx.AsyncClient() as client: - assert self.token, "spark wallet token does not exist" - r = await client.post( - self.url + "/rpc", - headers={"X-Access": self.token}, - json={"method": key, "params": params}, - timeout=60 * 60 * 24, - ) - r.raise_for_status() + r = await self.client.post( + "/rpc", + json={"method": key, "params": params}, + timeout=60 * 60 * 24, + ) + r.raise_for_status() except ( OSError, httpx.ConnectError, @@ -224,17 +228,16 @@ class SparkWallet(Wallet): raise KeyError("supplied an invalid checking_id") async def paid_invoices_stream(self) -> AsyncGenerator[str, None]: - url = f"{self.url}/stream?access-key={self.token}" + url = f"/stream?access-key={self.token}" while True: try: - async with httpx.AsyncClient(timeout=None) as client: - async with client.stream("GET", url) as r: - async for line in r.aiter_lines(): - if line.startswith("data:"): - data = json.loads(line[5:]) - if "pay_index" in data and data.get("status") == "paid": - yield data["label"] + async with self.client.stream("GET", url, timeout=None) as r: + async for line in r.aiter_lines(): + if line.startswith("data:"): + data = json.loads(line[5:]) + if "pay_index" in data and data.get("status") == "paid": + yield data["label"] except ( OSError, httpx.ReadError, diff --git a/tests/core/views/test_api.py b/tests/core/views/test_api.py index 4f8a11540..92a2d2eae 100644 --- a/tests/core/views/test_api.py +++ b/tests/core/views/test_api.py @@ -6,12 +6,12 @@ import pytest from lnbits import bolt11 from lnbits.core.models import Payment -from lnbits.core.views.api import api_payment +from lnbits.core.views.api import api_auditor, api_payment from lnbits.db import DB_TYPE, SQLITE from lnbits.settings import get_wallet_class from tests.conftest import CreateInvoiceData, api_payments_create_invoice -from ...helpers import get_random_invoice_data, is_fake +from ...helpers import get_random_invoice_data, is_fake, pay_real_invoice WALLET = get_wallet_class() @@ -320,11 +320,17 @@ async def test_create_invoice_with_unhashed_description(client, inkey_headers_to return invoice +async def get_node_balance_sats(): + audit = await api_auditor() + return audit["node_balance_msats"] / 1000 + + @pytest.mark.asyncio @pytest.mark.skipif(is_fake, reason="this only works in regtest") async def test_pay_real_invoice( client, real_invoice, adminkey_headers_from, inkey_headers_from ): + prev_balance = await get_node_balance_sats() response = await client.post( "/api/v1/payments", json=real_invoice, headers=adminkey_headers_from ) @@ -337,5 +343,46 @@ async def test_pay_real_invoice( response = await api_payment( invoice["payment_hash"], inkey_headers_from["X-Api-Key"] ) - assert type(response) == dict - assert response["paid"] is True + assert response["paid"] + + status = await WALLET.get_payment_status(invoice["payment_hash"]) + assert status.paid + + await asyncio.sleep(0.3) + balance = await get_node_balance_sats() + assert prev_balance - balance == 100 + + +@pytest.mark.asyncio +@pytest.mark.skipif(is_fake, reason="this only works in regtest") +async def test_create_real_invoice(client, adminkey_headers_from, inkey_headers_from): + prev_balance = await get_node_balance_sats() + create_invoice = CreateInvoiceData(out=False, amount=1000, memo="test") + response = await client.post( + "/api/v1/payments", + json=create_invoice.dict(), + headers=adminkey_headers_from, + ) + assert response.status_code < 300 + invoice = response.json() + response = await api_payment( + invoice["payment_hash"], inkey_headers_from["X-Api-Key"] + ) + assert not response["paid"] + + async def listen(): + async for payment_hash in get_wallet_class().paid_invoices_stream(): + assert payment_hash == invoice["payment_hash"] + return + + task = asyncio.create_task(listen()) + pay_real_invoice(invoice["payment_request"]) + await asyncio.wait_for(task, timeout=3) + response = await api_payment( + invoice["payment_hash"], inkey_headers_from["X-Api-Key"] + ) + assert response["paid"] + + await asyncio.sleep(0.3) + balance = await get_node_balance_sats() + assert balance - prev_balance == create_invoice.amount diff --git a/tests/helpers.py b/tests/helpers.py index f00d684f2..d386ac2d0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -63,13 +63,12 @@ def run_cmd_json(cmd: str) -> dict: def get_real_invoice(sats: int) -> dict: - msats = sats * 1000 - return run_cmd_json(f"{docker_lightning_cli} addinvoice {msats}") + return run_cmd_json(f"{docker_lightning_cli} addinvoice {sats}") def pay_real_invoice(invoice: str) -> Popen: return Popen( - f"{docker_lightning_cli} payinvoice {invoice}", + f"{docker_lightning_cli} payinvoice --force {invoice}", shell=True, stdin=PIPE, stdout=PIPE,