diff --git a/lnbits/core/services.py b/lnbits/core/services.py index c0ce6ac3f..3fa8d84f0 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -44,7 +44,7 @@ from .crud import ( update_super_user, ) from .helpers import to_valid_user_id -from .models import Payment +from .models import Payment, Wallet class PaymentFailure(Exception): @@ -172,7 +172,7 @@ async def pay_invoice( logger.debug(f"creating temporary internal payment with id {internal_id}") # create a new payment from this wallet - await create_payment( + new_payment = await create_payment( checking_id=internal_id, fee=0, pending=False, @@ -184,7 +184,7 @@ async def pay_invoice( # create a temporary payment here so we can check if # the balance is enough in the next step try: - await create_payment( + new_payment = await create_payment( checking_id=temp_id, fee=-fee_reserve_msat, conn=conn, @@ -215,6 +215,7 @@ async def pay_invoice( await update_payment_status( checking_id=internal_checking_id, pending=False, conn=conn ) + await send_payment_notification(wallet, new_payment) # notify receiver asynchronously from lnbits.tasks import internal_invoice_queue @@ -248,16 +249,11 @@ async def pay_invoice( conn=conn, ) wallet = await get_wallet(wallet_id, conn=conn) - if wallet: - await websocketUpdater( - wallet_id, - json.dumps( - { - "wallet_balance": wallet.balance or None, - "payment": payment._asdict(), - } - ), - ) + updated = await get_wallet_payment( + wallet_id, payment.checking_id, conn=conn + ) + if wallet and updated: + await send_payment_notification(wallet, updated) logger.debug(f"payment successful {payment.checking_id}") elif payment.checking_id is None and payment.ok is False: # payment failed @@ -431,6 +427,18 @@ def fee_reserve(amount_msat: int) -> int: return max(int(reserve_min), int(amount_msat * reserve_percent / 100.0)) +async def send_payment_notification(wallet: Wallet, payment: Payment): + await websocketUpdater( + wallet.id, + json.dumps( + { + "wallet_balance": wallet.balance, + "payment": payment.dict(), + } + ), + ) + + async def update_wallet_balance(wallet_id: str, amount: int): payment_hash, _ = await create_invoice( wallet_id=wallet_id, diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index 14f8478d1..4572f73e8 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -1,5 +1,4 @@ import asyncio -import json from typing import Dict, Optional import httpx @@ -11,7 +10,7 @@ from lnbits.tasks import SseListenersDict, register_invoice_listener from . import db from .crud import get_balance_notify, get_wallet from .models import Payment -from .services import get_balance_delta, switch_to_voidwallet, websocketUpdater +from .services import get_balance_delta, send_payment_notification, switch_to_voidwallet api_invoice_listeners: Dict[str, asyncio.Queue] = SseListenersDict( "api_invoice_listeners" @@ -123,15 +122,7 @@ async def wait_for_paid_invoices(invoice_paid_queue: asyncio.Queue): await dispatch_api_invoice_listeners(payment) wallet = await get_wallet(payment.wallet_id) if wallet: - await websocketUpdater( - payment.wallet_id, - json.dumps( - { - "wallet_balance": wallet.balance or None, - "payment": payment.dict(), - } - ), - ) + await send_payment_notification(wallet, payment) # dispatch webhook if payment.webhook and not payment.webhook_status: await dispatch_webhook(payment) diff --git a/tests/conftest.py b/tests/conftest.py index 48788b671..2e6a28be3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import asyncio +import pytest import pytest_asyncio +from fastapi.testclient import TestClient from httpx import AsyncClient from lnbits.app import create_app @@ -41,6 +43,11 @@ async def client(app): await client.aclose() +@pytest.fixture(scope="session") +def test_client(app): + return TestClient(app) + + @pytest_asyncio.fixture(scope="session") async def db(): yield Database("database") @@ -63,6 +70,12 @@ async def from_wallet(from_user): yield wallet +@pytest.fixture +def from_wallet_ws(from_wallet, test_client): + with test_client.websocket_connect(f"/api/v1/ws/{from_wallet.id}") as ws: + yield ws + + @pytest_asyncio.fixture(scope="session") async def to_user(): user = await create_account() @@ -80,6 +93,12 @@ async def to_wallet(to_user): yield wallet +@pytest.fixture +def to_wallet_ws(to_wallet, test_client): + with test_client.websocket_connect(f"/api/v1/ws/{to_wallet.id}") as ws: + yield ws + + @pytest_asyncio.fixture(scope="session") async def inkey_headers_from(from_wallet): wallet = from_wallet diff --git a/tests/core/views/test_api.py b/tests/core/views/test_api.py index a645a4dba..57250cdb1 100644 --- a/tests/core/views/test_api.py +++ b/tests/core/views/test_api.py @@ -113,14 +113,28 @@ async def test_create_invoice_custom_expiry(client, inkey_headers_to): # check POST /api/v1/payments: make payment @pytest.mark.asyncio -async def test_pay_invoice(client, invoice, adminkey_headers_from): +async def test_pay_invoice( + client, from_wallet_ws, to_wallet_ws, invoice, adminkey_headers_from +): data = {"out": True, "bolt11": invoice["payment_request"]} response = await client.post( "/api/v1/payments", json=data, headers=adminkey_headers_from ) assert response.status_code < 300 - assert len(response.json()["payment_hash"]) == 64 - assert len(response.json()["checking_id"]) > 0 + invoice = response.json() + assert len(invoice["payment_hash"]) == 64 + assert len(invoice["checking_id"]) > 0 + + data = from_wallet_ws.receive_json() + assert "wallet_balance" in data + payment = Payment(**data["payment"]) + assert payment.payment_hash == invoice["payment_hash"] + + # websocket from to_wallet cant be tested before https://github.com/lnbits/lnbits/pull/1793 + # data = to_wallet_ws.receive_json() + # assert "wallet_balance" in data + # payment = Payment(**data["payment"]) + # assert payment.payment_hash == invoice["payment_hash"] # check GET /api/v1/payments/: payment status @@ -330,7 +344,7 @@ async def get_node_balance_sats(): @pytest.mark.asyncio @pytest.mark.skipif(is_fake, reason="this only works in regtest") async def test_pay_real_invoice( - client, real_invoice, adminkey_headers_from, inkey_headers_from + client, real_invoice, adminkey_headers_from, inkey_headers_from, from_wallet_ws ): prev_balance = await get_node_balance_sats() response = await client.post( @@ -341,6 +355,11 @@ async def test_pay_real_invoice( assert len(invoice["payment_hash"]) == 64 assert len(invoice["checking_id"]) > 0 + data = from_wallet_ws.receive_json() + assert "wallet_balance" in data + payment = Payment(**data["payment"]) + assert payment.payment_hash == invoice["payment_hash"] + # check the payment status response = await api_payment( invoice["payment_hash"], inkey_headers_from["X-Api-Key"]