fix lnurlp and lnurlpayout mypy issue

This commit is contained in:
dni ⚡
2023-01-04 13:24:45 +01:00
parent 268d101189
commit 9d2ef274b9
8 changed files with 60 additions and 85 deletions

View File

@@ -4,7 +4,6 @@ import json
import httpx import httpx
from loguru import logger from loguru import logger
from lnbits.core import db as core_db
from lnbits.core.crud import update_payment_extra from lnbits.core.crud import update_payment_extra
from lnbits.core.models import Payment from lnbits.core.models import Payment
from lnbits.helpers import get_current_extension_name from lnbits.helpers import get_current_extension_name
@@ -22,9 +21,8 @@ async def wait_for_paid_invoices():
await on_invoice_paid(payment) await on_invoice_paid(payment)
async def on_invoice_paid(payment: Payment) -> None: async def on_invoice_paid(payment: Payment):
if payment.extra.get("tag") != "lnurlp": if not payment.extra or payment.extra.get("tag") != "lnurlp":
# not an lnurlp invoice
return return
if payment.extra.get("wh_status"): if payment.extra.get("wh_status"):
@@ -35,22 +33,24 @@ async def on_invoice_paid(payment: Payment) -> None:
if pay_link and pay_link.webhook_url: if pay_link and pay_link.webhook_url:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
kwargs = { r: httpx.Response = await client.post(
"json": { pay_link.webhook_url,
json={
"payment_hash": payment.payment_hash, "payment_hash": payment.payment_hash,
"payment_request": payment.bolt11, "payment_request": payment.bolt11,
"amount": payment.amount, "amount": payment.amount,
"comment": payment.extra.get("comment"), "comment": payment.extra.get("comment"),
"lnurlp": pay_link.id, "lnurlp": pay_link.id,
"lnurlp": pay_link.id,
"body": json.loads(pay_link.webhook_body)
if pay_link.webhook_body
else "",
}, },
"timeout": 40, headers=json.loads(pay_link.webhook_headers)
} if pay_link.webhook_headers
if pay_link.webhook_body: else None,
kwargs["json"]["body"] = json.loads(pay_link.webhook_body) timeout=40,
if pay_link.webhook_headers: )
kwargs["headers"] = json.loads(pay_link.webhook_headers)
r: httpx.Response = await client.post(pay_link.webhook_url, **kwargs)
await mark_webhook_sent( await mark_webhook_sent(
payment.payment_hash, payment.payment_hash,
r.status_code, r.status_code,

View File

@@ -1,7 +1,6 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Request from fastapi import Depends, Request
from fastapi.params import Depends
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse

View File

@@ -1,9 +1,7 @@
import json import json
from http import HTTPStatus from http import HTTPStatus
from fastapi import Request from fastapi import Depends, Query, Request
from fastapi.param_functions import Query
from fastapi.params import Depends
from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl # type: ignore from lnurl.exceptions import InvalidUrl as LnurlInvalidUrl # type: ignore
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@@ -36,7 +34,8 @@ async def api_links(
wallet_ids = [wallet.wallet.id] wallet_ids = [wallet.wallet.id]
if all_wallets: if all_wallets:
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids user = await get_user(wallet.wallet.user)
wallet_ids = user.wallet_ids if user else []
try: try:
return [ return [
@@ -137,6 +136,7 @@ async def api_link_create_or_update(
link = await update_pay_link(**data.dict(), link_id=link_id) link = await update_pay_link(**data.dict(), link_id=link_id)
else: else:
link = await create_pay_link(data, wallet_id=wallet.wallet.id) link = await create_pay_link(data, wallet_id=wallet.wallet.id)
assert link
return {**link.dict(), "lnurl": link.lnurl(request)} return {**link.dict(), "lnurl": link.lnurl(request)}

View File

@@ -53,7 +53,7 @@ async def get_lnurlpayouts(wallet_ids: Union[str, List[str]]) -> List[lnurlpayou
f"SELECT * FROM lnurlpayout.lnurlpayouts WHERE wallet IN ({q})", (*wallet_ids,) f"SELECT * FROM lnurlpayout.lnurlpayouts WHERE wallet IN ({q})", (*wallet_ids,)
) )
return [lnurlpayout(**row) if row else None for row in rows] return [lnurlpayout(**row) for row in rows]
async def delete_lnurlpayout(lnurlpayout_id: str) -> None: async def delete_lnurlpayout(lnurlpayout_id: str) -> None:

View File

@@ -2,14 +2,13 @@ import asyncio
from http import HTTPStatus from http import HTTPStatus
import httpx import httpx
from lnurl import decode as lnurl_decode
from loguru import logger from loguru import logger
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from lnbits.core import db as core_db
from lnbits.core.crud import get_wallet from lnbits.core.crud import get_wallet
from lnbits.core.models import Payment from lnbits.core.models import Payment
from lnbits.core.services import pay_invoice from lnbits.core.services import pay_invoice
from lnbits.core.views.api import api_payments_decode
from lnbits.helpers import get_current_extension_name from lnbits.helpers import get_current_extension_name
from lnbits.tasks import register_invoice_listener from lnbits.tasks import register_invoice_listener
@@ -35,6 +34,7 @@ async def on_invoice_paid(payment: Payment) -> None:
# Check the wallet balance is more than the threshold # Check the wallet balance is more than the threshold
wallet = await get_wallet(lnurlpayout_link.wallet) wallet = await get_wallet(lnurlpayout_link.wallet)
assert wallet
threshold = lnurlpayout_link.threshold + (lnurlpayout_link.threshold * 0.02) threshold = lnurlpayout_link.threshold + (lnurlpayout_link.threshold * 0.02)
if wallet.balance < threshold: if wallet.balance < threshold:
@@ -42,14 +42,10 @@ async def on_invoice_paid(payment: Payment) -> None:
# Get the invoice from the LNURL to pay # Get the invoice from the LNURL to pay
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
url = await api_payments_decode({"data": lnurlpayout_link.lnurlpay}) url = lnurl_decode(lnurlpayout_link.lnurlpay)
if str(url["domain"])[0:4] != "http":
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="LNURL broken"
)
try: try:
r = await client.get(str(url["domain"]), timeout=40) r = await client.get(str(url), timeout=40)
res = r.json() res = r.json()
try: try:
r = await client.get( r = await client.get(

View File

@@ -1,16 +1,13 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Request from fastapi import Depends, Request
from fastapi.params import Depends
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
from lnbits.core.models import User from lnbits.core.models import User
from lnbits.decorators import check_user_exists from lnbits.decorators import check_user_exists
from . import lnurlpayout_ext, lnurlpayout_renderer from . import lnurlpayout_ext, lnurlpayout_renderer
from .crud import get_lnurlpayout
templates = Jinja2Templates(directory="templates") templates = Jinja2Templates(directory="templates")

View File

@@ -1,13 +1,10 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import Query from fastapi import Depends, Query
from fastapi.params import Depends from lnurl import decode as lnurl_decode
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from lnbits.core.crud import get_payments, get_user from lnbits.core.crud import get_user
from lnbits.core.models import Payment
from lnbits.core.services import create_invoice
from lnbits.core.views.api import api_payment, api_payments_decode
from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key
from . import lnurlpayout_ext from . import lnurlpayout_ext
@@ -18,8 +15,9 @@ from .crud import (
get_lnurlpayout_from_wallet, get_lnurlpayout_from_wallet,
get_lnurlpayouts, get_lnurlpayouts,
) )
from .models import CreateLnurlPayoutData, lnurlpayout from .models import CreateLnurlPayoutData
from .tasks import on_invoice_paid
# from .tasks import on_invoice_paid
@lnurlpayout_ext.get("/api/v1/lnurlpayouts", status_code=HTTPStatus.OK) @lnurlpayout_ext.get("/api/v1/lnurlpayouts", status_code=HTTPStatus.OK)
@@ -28,7 +26,8 @@ async def api_lnurlpayouts(
): ):
wallet_ids = [wallet.wallet.id] wallet_ids = [wallet.wallet.id]
if all_wallets: if all_wallets:
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids user = await get_user(wallet.wallet.user)
wallet_ids = user.wallet_ids if user else []
return [lnurlpayout.dict() for lnurlpayout in await get_lnurlpayouts(wallet_ids)] return [lnurlpayout.dict() for lnurlpayout in await get_lnurlpayouts(wallet_ids)]
@@ -42,24 +41,16 @@ async def api_lnurlpayout_create(
status_code=HTTPStatus.FORBIDDEN, status_code=HTTPStatus.FORBIDDEN,
detail="Wallet already has lnurlpayout set", detail="Wallet already has lnurlpayout set",
) )
return _ = lnurl_decode(data.lnurlpay)
url = await api_payments_decode({"data": data.lnurlpay})
if "domain" not in url:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="LNURL could not be decoded"
)
return
if str(url["domain"])[0:4] != "http":
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail="Not valid LNURL")
return
lnurlpayout = await create_lnurlpayout( lnurlpayout = await create_lnurlpayout(
wallet_id=wallet.wallet.id, admin_key=wallet.wallet.adminkey, data=data wallet_id=wallet.wallet.id, admin_key=wallet.wallet.adminkey, data=data
) )
if not lnurlpayout: if not lnurlpayout:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Failed to save LNURLPayout" status_code=HTTPStatus.FORBIDDEN, detail="Failed to save LNURLPayout"
) )
return
return lnurlpayout.dict() return lnurlpayout.dict()
@@ -83,36 +74,29 @@ async def api_lnurlpayout_delete(
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT
@lnurlpayout_ext.get("/api/v1/lnurlpayouts/{lnurlpayout_id}", status_code=HTTPStatus.OK) # TODO: what is this?!
async def api_lnurlpayout_check(
lnurlpayout_id: str, wallet: WalletTypeInfo = Depends(get_key_type) # @lnurlpayout_ext.get("/api/v1/lnurlpayouts/{lnurlpayout_id}", status_code=HTTPStatus.OK)
): # async def api_lnurlpayout_check(
lnurlpayout = await get_lnurlpayout(lnurlpayout_id) # lnurlpayout_id: str, wallet: WalletTypeInfo = Depends(get_key_type)
## THIS # ):
mock_payment = Payment( # lnurlpayout = await get_lnurlpayout(lnurlpayout_id)
checking_id="mock", # ## THIS
pending=False, # mock_payment = Payment(
amount=1, # checking_id="mock",
fee=1, # pending=False,
time=0000, # amount=1,
bolt11="mock", # fee=1,
preimage="mock", # time=0000,
payment_hash="mock", # bolt11="mock",
wallet_id=lnurlpayout.wallet, # preimage="mock",
) # payment_hash="mock",
## INSTEAD OF THIS # wallet_id=lnurlpayout.wallet,
# payments = await get_payments(
# wallet_id=lnurlpayout.wallet, complete=True, pending=False, outgoing=True, incoming=True
# ) # )
# ## INSTEAD OF THIS
# # payments = await get_payments(
# # wallet_id=lnurlpayout.wallet, complete=True, pending=False, outgoing=True, incoming=True
# # )
result = await on_invoice_paid(mock_payment) # result = await on_invoice_paid(mock_payment)
return # return
# get payouts func
# lnurlpayouts = await get_lnurlpayouts(wallet_ids)
# for lnurlpayout in lnurlpayouts:
# payments = await get_payments(
# wallet_id=lnurlpayout.wallet, complete=True, pending=False, outgoing=True, incoming=True
# )
# await on_invoice_paid(payments[0])

View File

@@ -98,7 +98,6 @@ exclude = """(?x)(
| ^lnbits/extensions/lnaddress. | ^lnbits/extensions/lnaddress.
| ^lnbits/extensions/lndhub. | ^lnbits/extensions/lndhub.
| ^lnbits/extensions/lnurldevice. | ^lnbits/extensions/lnurldevice.
| ^lnbits/extensions/lnurlp.
| ^lnbits/extensions/offlineshop. | ^lnbits/extensions/offlineshop.
| ^lnbits/extensions/satspay. | ^lnbits/extensions/satspay.
| ^lnbits/extensions/streamalerts. | ^lnbits/extensions/streamalerts.