Merge remote-tracking branch 'origin/FastAPI' into FastAPI

This commit is contained in:
benarc
2021-11-30 10:08:46 +00:00
7 changed files with 39 additions and 17 deletions

View File

@@ -172,6 +172,8 @@ async def pay_invoice(
) )
await delete_payment(temp_id, conn=conn) await delete_payment(temp_id, conn=conn)
else: else:
async with db.connect() as conn:
await delete_payment(temp_id, conn=conn)
raise PaymentFailure( raise PaymentFailure(
payment.error_message payment.error_message
or "Payment failed, but backend didn't give us an error message." or "Payment failed, but backend didn't give us an error message."

View File

@@ -72,7 +72,7 @@ async def api_update_wallet(
@core_app.get("/api/v1/payments") @core_app.get("/api/v1/payments")
async def api_payments(wallet: WalletTypeInfo = Depends(get_key_type)): async def api_payments(wallet: WalletTypeInfo = Depends(get_key_type)):
await get_payments(wallet_id=wallet.wallet.id, pending=True, complete=True) await get_payments(wallet_id=wallet.wallet.id, pending=True, complete=True)
pendingPayments = await get_payments(wallet_id=wallet.wallet.id, pending=True) pendingPayments = await get_payments(wallet_id=wallet.wallet.id, pending=True, exclude_uncheckable=True)
for payment in pendingPayments: for payment in pendingPayments:
await check_invoice_status( await check_invoice_status(
wallet_id=payment.wallet_id, payment_hash=payment.payment_hash wallet_id=payment.wallet_id, payment_hash=payment.payment_hash
@@ -193,7 +193,8 @@ async def api_payments_create(
invoiceData: CreateInvoiceData = Body(...), invoiceData: CreateInvoiceData = Body(...),
): ):
if wallet.wallet_type < 0 or wallet.wallet_type > 2: if wallet.wallet_type < 0 or wallet.wallet_type > 2:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="Key is invalid") raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Key is invalid")
if invoiceData.out is True and wallet.wallet_type == 0: if invoiceData.out is True and wallet.wallet_type == 0:
if not invoiceData.bolt11: if not invoiceData.bolt11:
@@ -204,7 +205,8 @@ async def api_payments_create(
return await api_payments_pay_invoice( return await api_payments_pay_invoice(
invoiceData.bolt11, wallet.wallet invoiceData.bolt11, wallet.wallet
) # admin key ) # admin key
return await api_payments_create_invoice(invoiceData, wallet.wallet) # invoice key # invoice key
return await api_payments_create_invoice(invoiceData, wallet.wallet)
class CreateLNURLData(BaseModel): class CreateLNURLData(BaseModel):
@@ -372,14 +374,16 @@ async def api_lnurlscan(code: str):
params.update(callback=url) # with k1 already in it params.update(callback=url) # with k1 already in it
lnurlauth_key = g().wallet.lnurlauth_key(domain) lnurlauth_key = g().wallet.lnurlauth_key(domain)
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex()) params.update(
pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
else: else:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(url, timeout=5) r = await client.get(url, timeout=5)
if r.is_error: if r.is_error:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.SERVICE_UNAVAILABLE, status_code=HTTPStatus.SERVICE_UNAVAILABLE,
detail={"domain": domain, "message": "failed to get parameters"}, detail={"domain": domain,
"message": "failed to get parameters"},
) )
try: try:
@@ -409,7 +413,8 @@ async def api_lnurlscan(code: str):
if tag == "withdrawRequest": if tag == "withdrawRequest":
params.update(kind="withdraw") params.update(kind="withdraw")
params.update(fixed=data["minWithdrawable"] == data["maxWithdrawable"]) params.update(fixed=data["minWithdrawable"]
== data["maxWithdrawable"])
# callback with k1 already in it # callback with k1 already in it
parsed_callback: ParseResult = urlparse(data["callback"]) parsed_callback: ParseResult = urlparse(data["callback"])

View File

@@ -2,6 +2,7 @@ from sqlite3 import Row
from fastapi.param_functions import Query from fastapi.param_functions import Query
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional
class CreateUserData(BaseModel): class CreateUserData(BaseModel):
@@ -22,8 +23,8 @@ class Users(BaseModel):
id: str id: str
name: str name: str
admin: str admin: str
email: str email: Optional[str] = None
password: str password: Optional[str] = None
class Wallets(BaseModel): class Wallets(BaseModel):

View File

@@ -23,7 +23,7 @@ from .crud import (
) )
from .models import CreateUserData, CreateUserWallet from .models import CreateUserData, CreateUserWallet
### Users # Users
@usermanager_ext.get("/api/v1/users", status_code=HTTPStatus.OK) @usermanager_ext.get("/api/v1/users", status_code=HTTPStatus.OK)
@@ -63,7 +63,7 @@ async def api_usermanager_users_delete(
raise HTTPException(status_code=HTTPStatus.NO_CONTENT) raise HTTPException(status_code=HTTPStatus.NO_CONTENT)
###Activate Extension # Activate Extension
@usermanager_ext.post("/api/v1/extensions") @usermanager_ext.post("/api/v1/extensions")
@@ -79,7 +79,7 @@ async def api_usermanager_activate_extension(
return {"extension": "updated"} return {"extension": "updated"}
###Wallets # Wallets
@usermanager_ext.post("/api/v1/wallets") @usermanager_ext.post("/api/v1/wallets")
@@ -98,7 +98,7 @@ async def api_usermanager_wallets(wallet: WalletTypeInfo = Depends(get_key_type)
return [wallet.dict() for wallet in await get_usermanager_wallets(admin_id)] return [wallet.dict() for wallet in await get_usermanager_wallets(admin_id)]
@usermanager_ext.get("/api/v1/wallets/{wallet_id}") @usermanager_ext.get("/api/v1/transactions/{wallet_id}")
async def api_usermanager_wallet_transactions( async def api_usermanager_wallet_transactions(
wallet_id, wallet: WalletTypeInfo = Depends(get_key_type) wallet_id, wallet: WalletTypeInfo = Depends(get_key_type)
): ):

View File

@@ -23,7 +23,7 @@ class Jinja2Templates(templating.Jinja2Templates):
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
@jinja2.contextfunction @jinja2.contextfunction
def url_for(context: dict, name: str, **path_params: typing.Any) -> str: def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request: Request = context["request"] # type: starlette.requests.Request request: Request = context["request"]
return request.app.url_path_for(name, **path_params) return request.app.url_path_for(name, **path_params)
def url_params_update(init: QueryParams, **new: typing.Any) -> QueryParams: def url_params_update(init: QueryParams, **new: typing.Any) -> QueryParams:

View File

@@ -5,6 +5,8 @@ import base64
from os import getenv from os import getenv
from typing import Optional, Dict, AsyncGenerator from typing import Optional, Dict, AsyncGenerator
from lnbits import bolt11 as lnbits_bolt11
from .base import ( from .base import (
StatusResponse, StatusResponse,
InvoiceResponse, InvoiceResponse,
@@ -21,7 +23,8 @@ class LndRestWallet(Wallet):
endpoint = getenv("LND_REST_ENDPOINT") endpoint = getenv("LND_REST_ENDPOINT")
endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint endpoint = endpoint[:-1] if endpoint.endswith("/") else endpoint
endpoint = ( endpoint = (
"https://" + endpoint if not endpoint.startswith("http") else endpoint "https://" +
endpoint if not endpoint.startswith("http") else endpoint
) )
self.endpoint = endpoint self.endpoint = endpoint
@@ -89,10 +92,21 @@ class LndRestWallet(Wallet):
async def pay_invoice(self, bolt11: str) -> PaymentResponse: async def pay_invoice(self, bolt11: str) -> PaymentResponse:
async with httpx.AsyncClient(verify=self.cert) as client: async with httpx.AsyncClient(verify=self.cert) as client:
# set the fee limit for the payment
invoice = lnbits_bolt11.decode(bolt11)
lnrpcFeeLimit = dict()
if invoice.amount_msat > 1000_000:
lnrpcFeeLimit["percent"] = "1" # in percent
else:
lnrpcFeeLimit["fixed"] = "10" # in sat
r = await client.post( r = await client.post(
url=f"{self.endpoint}/v1/channels/transactions", url=f"{self.endpoint}/v1/channels/transactions",
headers=self.auth, headers=self.auth,
json={"payment_request": bolt11}, json={
"payment_request": bolt11,
"fee_limit": lnrpcFeeLimit,
},
timeout=180, timeout=180,
) )
@@ -168,7 +182,8 @@ class LndRestWallet(Wallet):
except: except:
continue continue
payment_hash = base64.b64decode(inv["r_hash"]).hex() payment_hash = base64.b64decode(
inv["r_hash"]).hex()
yield payment_hash yield payment_hash
except (OSError, httpx.ConnectError, httpx.ReadError): except (OSError, httpx.ConnectError, httpx.ReadError):
pass pass

View File

@@ -1,2 +1 @@
[mypy] [mypy]
plugins = trio_typing.plugin