diff --git a/lnbits/extensions/satspay/crud.py b/lnbits/extensions/satspay/crud.py index 784338387..4fb14695e 100644 --- a/lnbits/extensions/satspay/crud.py +++ b/lnbits/extensions/satspay/crud.py @@ -12,14 +12,13 @@ from . import db from .helpers import fetch_onchain_balance from .models import Charges, CreateCharge, SatsPayThemes -###############CHARGES########################## - -async def create_charge(user: str, data: CreateCharge) -> Charges: +async def create_charge(user: str, data: CreateCharge) -> Optional[Charges]: data = CreateCharge(**data.dict()) charge_id = urlsafe_short_hash() if data.onchainwallet: config = await get_config(user) + assert config data.extra = json.dumps( {"mempool_endpoint": config.mempool_endpoint, "network": config.network} ) @@ -92,7 +91,7 @@ async def update_charge(charge_id: str, **kwargs) -> Optional[Charges]: return Charges.from_row(row) if row else None -async def get_charge(charge_id: str) -> Charges: +async def get_charge(charge_id: str) -> Optional[Charges]: row = await db.fetchone("SELECT * FROM satspay.charges WHERE id = ?", (charge_id,)) return Charges.from_row(row) if row else None @@ -111,6 +110,7 @@ async def delete_charge(charge_id: str) -> None: async def check_address_balance(charge_id: str) -> Optional[Charges]: charge = await get_charge(charge_id) + assert charge if not charge.paid: if charge.onchainaddress: @@ -131,7 +131,7 @@ async def check_address_balance(charge_id: str) -> Optional[Charges]: ################## SETTINGS ################### -async def save_theme(data: SatsPayThemes, css_id: str = None): +async def save_theme(data: SatsPayThemes, css_id: Optional[str]): # insert or update if css_id: await db.execute( @@ -162,7 +162,7 @@ async def save_theme(data: SatsPayThemes, css_id: str = None): return await get_theme(css_id) -async def get_theme(css_id: str) -> SatsPayThemes: +async def get_theme(css_id: str) -> Optional[SatsPayThemes]: row = await db.fetchone("SELECT * FROM satspay.themes WHERE css_id = ?", (css_id,)) return SatsPayThemes.from_row(row) if row else None diff --git a/lnbits/extensions/satspay/helpers.py b/lnbits/extensions/satspay/helpers.py index b21a3ae29..8596d3684 100644 --- a/lnbits/extensions/satspay/helpers.py +++ b/lnbits/extensions/satspay/helpers.py @@ -32,6 +32,7 @@ def public_charge(charge: Charges): async def call_webhook(charge: Charges): async with httpx.AsyncClient() as client: try: + assert charge.webhook r = await client.post( charge.webhook, json=public_charge(charge), @@ -54,6 +55,8 @@ async def fetch_onchain_balance(charge: Charges): if charge.config.network == "Testnet" else charge.config.mempool_endpoint ) + assert endpoint + assert charge.onchainaddress async with httpx.AsyncClient() as client: r = await client.get(endpoint + "/api/address/" + charge.onchainaddress) return r.json()["chain_stats"]["funded_txo_sum"] diff --git a/lnbits/extensions/satspay/tasks.py b/lnbits/extensions/satspay/tasks.py index ce54b44a2..1f79d89ff 100644 --- a/lnbits/extensions/satspay/tasks.py +++ b/lnbits/extensions/satspay/tasks.py @@ -22,10 +22,14 @@ async def wait_for_paid_invoices(): async def on_invoice_paid(payment: Payment) -> None: + if not payment.extra: + return + if payment.extra.get("tag") != "charge": # not a charge invoice return + assert payment.memo charge = await get_charge(payment.memo) if not charge: logger.error("this should never happen", payment) @@ -33,6 +37,7 @@ async def on_invoice_paid(payment: Payment) -> None: await payment.set_pending(False) charge = await check_address_balance(charge_id=charge.id) + assert charge if charge.must_call_webhook(): resp = await call_webhook(charge) diff --git a/lnbits/extensions/satspay/views.py b/lnbits/extensions/satspay/views.py index 90f8a6b93..15a4403db 100644 --- a/lnbits/extensions/satspay/views.py +++ b/lnbits/extensions/satspay/views.py @@ -1,10 +1,7 @@ from http import HTTPStatus -from fastapi import Response -from fastapi.param_functions import Depends +from fastapi import Depends, HTTPException, Request, Response from fastapi.templating import Jinja2Templates -from starlette.exceptions import HTTPException -from starlette.requests import Request from starlette.responses import HTMLResponse from lnbits.core.models import User diff --git a/lnbits/extensions/satspay/views_api.py b/lnbits/extensions/satspay/views_api.py index 08c731cb2..98c338edc 100644 --- a/lnbits/extensions/satspay/views_api.py +++ b/lnbits/extensions/satspay/views_api.py @@ -1,9 +1,8 @@ import json from http import HTTPStatus -from fastapi import Depends, Query +from fastapi import Depends, HTTPException, Query from loguru import logger -from starlette.exceptions import HTTPException from lnbits.decorators import ( WalletTypeInfo, @@ -29,8 +28,6 @@ from .crud import ( from .helpers import call_webhook, public_charge from .models import CreateCharge, SatsPayThemes -#############################CHARGES########################## - @satspay_ext.post("/api/v1/charge") async def api_charge_create( @@ -38,6 +35,7 @@ async def api_charge_create( ): try: charge = await create_charge(user=wallet.wallet.user, data=data) + assert charge return { **charge.dict(), **{"time_elapsed": charge.time_elapsed}, @@ -51,13 +49,15 @@ async def api_charge_create( ) -@satspay_ext.put("/api/v1/charge/{charge_id}") +@satspay_ext.put( + "/api/v1/charge/{charge_id}", dependencies=[Depends(require_admin_key)] +) async def api_charge_update( data: CreateCharge, - wallet: WalletTypeInfo = Depends(require_admin_key), - charge_id=None, + charge_id: str, ): charge = await update_charge(charge_id=charge_id, data=data) + assert charge return charge.dict() @@ -78,10 +78,8 @@ async def api_charges_retrieve(wallet: WalletTypeInfo = Depends(get_key_type)): return "" -@satspay_ext.get("/api/v1/charge/{charge_id}") -async def api_charge_retrieve( - charge_id, wallet: WalletTypeInfo = Depends(get_key_type) -): +@satspay_ext.get("/api/v1/charge/{charge_id}", dependencies=[Depends(get_key_type)]) +async def api_charge_retrieve(charge_id: str): charge = await get_charge(charge_id) if not charge: @@ -97,8 +95,8 @@ async def api_charge_retrieve( } -@satspay_ext.delete("/api/v1/charge/{charge_id}") -async def api_charge_delete(charge_id, wallet: WalletTypeInfo = Depends(get_key_type)): +@satspay_ext.delete("/api/v1/charge/{charge_id}", dependencies=[Depends(get_key_type)]) +async def api_charge_delete(charge_id: str): charge = await get_charge(charge_id) if not charge: @@ -155,7 +153,7 @@ async def api_themes_save( theme = await save_theme(css_id=css_id, data=data) else: data.user = wallet.wallet.user - theme = await save_theme(data=data) + theme = await save_theme(data=data, css_id="no_id") return theme @@ -169,8 +167,8 @@ async def api_themes_retrieve(wallet: WalletTypeInfo = Depends(get_key_type)): return "" -@satspay_ext.delete("/api/v1/themes/{theme_id}") -async def api_theme_delete(theme_id, wallet: WalletTypeInfo = Depends(get_key_type)): +@satspay_ext.delete("/api/v1/themes/{theme_id}", dependencies=[Depends(get_key_type)]) +async def api_theme_delete(theme_id): theme = await get_theme(theme_id) if not theme: diff --git a/pyproject.toml b/pyproject.toml index e2116ed08..cbbafad65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,6 @@ exclude = """(?x)( | ^lnbits/extensions/livestream. | ^lnbits/extensions/lnaddress. | ^lnbits/extensions/lnurldevice. - | ^lnbits/extensions/satspay. | ^lnbits/extensions/watchonly. | ^lnbits/wallets/lnd_grpc_files. )"""