mirror of
https://github.com/lnbits/lnbits.git
synced 2025-06-28 09:40:59 +02:00
fix: FastAPify how data or exceptions are returned
FastAPI handles returning HTTPStatus codes differently than Quart did
This commit is contained in:
parent
d9849d43d2
commit
fa08177317
@ -31,23 +31,18 @@ from ..tasks import api_invoice_listeners
|
|||||||
|
|
||||||
@core_app.get("/api/v1/wallet")
|
@core_app.get("/api/v1/wallet")
|
||||||
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
||||||
return (
|
return {"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat},
|
||||||
{"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat},
|
|
||||||
HTTPStatus.OK,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@core_app.put("/api/v1/wallet/{new_name}")
|
@core_app.put("/api/v1/wallet/{new_name}")
|
||||||
async def api_update_wallet(new_name: str, wallet: WalletTypeInfo = Depends(get_key_type)):
|
async def api_update_wallet(new_name: str, wallet: WalletTypeInfo = Depends(get_key_type)):
|
||||||
await update_wallet(wallet.wallet.id, new_name)
|
await update_wallet(wallet.wallet.id, new_name)
|
||||||
return (
|
return {
|
||||||
{
|
"id": wallet.wallet.id,
|
||||||
"id": wallet.wallet.id,
|
"name": wallet.wallet.name,
|
||||||
"name": wallet.wallet.name,
|
"balance": wallet.wallet.balance_msat,
|
||||||
"balance": wallet.wallet.balance_msat,
|
}
|
||||||
},
|
|
||||||
HTTPStatus.OK,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@core_app.get("/api/v1/payments")
|
@core_app.get("/api/v1/payments")
|
||||||
@ -92,7 +87,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
|
|||||||
conn=conn,
|
conn=conn,
|
||||||
)
|
)
|
||||||
except InvoiceFailure as e:
|
except InvoiceFailure as e:
|
||||||
return {"message": str(e)}, 520
|
raise HTTPException(status_code=520, detail=str(e))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
@ -128,16 +123,15 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
|
|||||||
except (httpx.ConnectError, httpx.RequestError):
|
except (httpx.ConnectError, httpx.RequestError):
|
||||||
lnurl_response = False
|
lnurl_response = False
|
||||||
|
|
||||||
return (
|
return {
|
||||||
{
|
"payment_hash": invoice.payment_hash,
|
||||||
"payment_hash": invoice.payment_hash,
|
"payment_request": payment_request,
|
||||||
"payment_request": payment_request,
|
# maintain backwards compatibility with API clients:
|
||||||
# maintain backwards compatibility with API clients:
|
"checking_id": invoice.payment_hash,
|
||||||
"checking_id": invoice.payment_hash,
|
"lnurl_response": lnurl_response,
|
||||||
"lnurl_response": lnurl_response,
|
}
|
||||||
},
|
|
||||||
HTTPStatus.CREATED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
|
async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
|
||||||
@ -147,26 +141,34 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
|
|||||||
payment_request=bolt11,
|
payment_request=bolt11,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"message": str(e)}, HTTPStatus.BAD_REQUEST
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return {"message": str(e)}, HTTPStatus.FORBIDDEN
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
except PaymentFailure as e:
|
except PaymentFailure as e:
|
||||||
return {"message": str(e)}, 520
|
raise HTTPException(
|
||||||
|
status_code=520,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
return (
|
return {
|
||||||
{
|
"payment_hash": payment_hash,
|
||||||
"payment_hash": payment_hash,
|
# maintain backwards compatibility with API clients:
|
||||||
# maintain backwards compatibility with API clients:
|
"checking_id": payment_hash,
|
||||||
"checking_id": payment_hash,
|
}
|
||||||
},
|
|
||||||
HTTPStatus.CREATED,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@core_app.post("/api/v1/payments", deprecated=True,
|
@core_app.post("/api/v1/payments", deprecated=True,
|
||||||
description="DEPRECATED. Use /api/v2/TBD and /api/v2/TBD instead")
|
description="DEPRECATED. Use /api/v2/TBD and /api/v2/TBD instead",
|
||||||
|
status_code=HTTPStatus.CREATED)
|
||||||
async def api_payments_create(wallet: WalletTypeInfo = Depends(get_key_type), out: bool = True,
|
async def api_payments_create(wallet: WalletTypeInfo = Depends(get_key_type), out: bool = True,
|
||||||
invoiceData: Optional[CreateInvoiceData] = Body(None),
|
invoiceData: Optional[CreateInvoiceData] = Body(None),
|
||||||
bolt11: Optional[str] = Query(None)):
|
bolt11: Optional[str] = Query(None)):
|
||||||
@ -201,33 +203,33 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
|
|||||||
if r.is_error:
|
if r.is_error:
|
||||||
raise httpx.ConnectError
|
raise httpx.ConnectError
|
||||||
except (httpx.ConnectError, httpx.RequestError):
|
except (httpx.ConnectError, httpx.RequestError):
|
||||||
return (
|
raise HTTPException(
|
||||||
{"message": f"Failed to connect to {domain}."},
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
HTTPStatus.BAD_REQUEST,
|
detail=f"Failed to connect to {domain}."
|
||||||
)
|
)
|
||||||
|
|
||||||
params = json.loads(r.text)
|
params = json.loads(r.text)
|
||||||
if params.get("status") == "ERROR":
|
if params.get("status") == "ERROR":
|
||||||
return ({"message": f"{domain} said: '{params.get('reason', '')}'"},
|
raise HTTPException(
|
||||||
HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
detail=f"{domain} said: '{params.get('reason', '')}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
invoice = bolt11.decode(params["pr"])
|
invoice = bolt11.decode(params["pr"])
|
||||||
if invoice.amount_msat != data.amount:
|
if invoice.amount_msat != data.amount:
|
||||||
return (
|
raise HTTPException(
|
||||||
{
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
"message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}."
|
detail=f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}."
|
||||||
},
|
|
||||||
HTTPStatus.BAD_REQUEST,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if invoice.description_hash != g().data["description_hash"]:
|
if invoice.description_hash != g().data["description_hash"]:
|
||||||
return (
|
raise HTTPException(
|
||||||
{
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
"message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}."
|
detail=f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}."
|
||||||
},
|
|
||||||
HTTPStatus.BAD_REQUEST,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
|
|
||||||
if params.get("successAction"):
|
if params.get("successAction"):
|
||||||
@ -242,15 +244,13 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
|
|||||||
extra=extra,
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return {
|
||||||
{
|
"success_action": params.get("successAction"),
|
||||||
"success_action": params.get("successAction"),
|
"payment_hash": payment_hash,
|
||||||
"payment_hash": payment_hash,
|
# maintain backwards compatibility with API clients:
|
||||||
# maintain backwards compatibility with API clients:
|
"checking_id": payment_hash,
|
||||||
"checking_id": payment_hash,
|
}
|
||||||
},
|
|
||||||
HTTPStatus.CREATED,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def subscribe(request: Request, wallet: Wallet):
|
async def subscribe(request: Request, wallet: Wallet):
|
||||||
this_wallet_id = wallet.wallet.id
|
this_wallet_id = wallet.wallet.id
|
||||||
@ -273,20 +273,21 @@ async def subscribe(request: Request, wallet: Wallet):
|
|||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
typ, data = await send_queue.get()
|
typ, data = await send_queue.get()
|
||||||
message = [f"event: {typ}".encode("utf-8")]
|
|
||||||
|
|
||||||
if data:
|
if data:
|
||||||
jdata = json.dumps(dict(data.dict(), pending=False))
|
jdata = json.dumps(dict(data.dict(), pending=False))
|
||||||
message.append(f"data: {jdata}".encode("utf-8"))
|
|
||||||
|
|
||||||
yield dict(data=jdata.encode("utf-8"), event=typ.encode("utf-8"))
|
# yield dict(id=1, event="this", data="1234")
|
||||||
|
# await asyncio.sleep(2)
|
||||||
|
yield dict(data=jdata, event=typ)
|
||||||
|
# yield dict(data=jdata.encode("utf-8"), event=typ.encode("utf-8"))
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@core_app.get("/api/v1/payments/sse")
|
@core_app.get("/api/v1/payments/sse")
|
||||||
async def api_payments_sse(request: Request, wallet: WalletTypeInfo = Depends(get_key_type)):
|
async def api_payments_sse(request: Request, wallet: WalletTypeInfo = Depends(get_key_type)):
|
||||||
return EventSourceResponse(subscribe(request, wallet))
|
return EventSourceResponse(subscribe(request, wallet), ping=20, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
@core_app.get("/api/v1/payments/{payment_hash}")
|
@core_app.get("/api/v1/payments/{payment_hash}")
|
||||||
@ -303,10 +304,8 @@ async def api_payment(payment_hash, wallet: WalletTypeInfo = Depends(get_key_typ
|
|||||||
except Exception:
|
except Exception:
|
||||||
return {"paid": False}, HTTPStatus.OK
|
return {"paid": False}, HTTPStatus.OK
|
||||||
|
|
||||||
return (
|
return {"paid": not payment.pending, "preimage": payment.preimage}
|
||||||
{"paid": not payment.pending, "preimage": payment.preimage},
|
|
||||||
HTTPStatus.OK,
|
|
||||||
)
|
|
||||||
|
|
||||||
@core_app.get("/api/v1/lnurlscan/{code}", dependencies=[Depends(WalletInvoiceKeyChecker())])
|
@core_app.get("/api/v1/lnurlscan/{code}", dependencies=[Depends(WalletInvoiceKeyChecker())])
|
||||||
async def api_lnurlscan(code: str):
|
async def api_lnurlscan(code: str):
|
||||||
@ -326,7 +325,7 @@ async def api_lnurlscan(code: str):
|
|||||||
)
|
)
|
||||||
# will proceed with these values
|
# will proceed with these values
|
||||||
else:
|
else:
|
||||||
return {"message": "invalid lnurl"}, HTTPStatus.BAD_REQUEST
|
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="invalid lnurl")
|
||||||
|
|
||||||
# params is what will be returned to the client
|
# params is what will be returned to the client
|
||||||
params: Dict = {"domain": domain}
|
params: Dict = {"domain": domain}
|
||||||
@ -341,28 +340,25 @@ async def api_lnurlscan(code: str):
|
|||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
r = await client.get(url, timeout=5)
|
r = await client.get(url, timeout=5)
|
||||||
if r.is_error:
|
if r.is_error:
|
||||||
return (
|
raise HTTPException(
|
||||||
{"domain": domain, "message": "failed to get parameters"},
|
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
detail={"domain": domain, "message": "failed to get parameters"}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(r.text)
|
data = json.loads(r.text)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
return (
|
raise HTTPException(
|
||||||
{
|
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
"domain": domain,
|
detail={"domain": domain, "message": f"got invalid response '{r.text[:200]}'"}
|
||||||
"message": f"got invalid response '{r.text[:200]}'",
|
|
||||||
},
|
|
||||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tag = data["tag"]
|
tag = data["tag"]
|
||||||
if tag == "channelRequest":
|
if tag == "channelRequest":
|
||||||
return (
|
raise HTTPException(
|
||||||
{"domain": domain, "kind": "channel", "message": "unsupported"},
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
HTTPStatus.BAD_REQUEST,
|
detail={"domain": domain, "kind": "channel", "message": "unsupported"}
|
||||||
)
|
)
|
||||||
|
|
||||||
params.update(**data)
|
params.update(**data)
|
||||||
@ -407,13 +403,13 @@ async def api_lnurlscan(code: str):
|
|||||||
|
|
||||||
params.update(commentAllowed=data.get("commentAllowed", 0))
|
params.update(commentAllowed=data.get("commentAllowed", 0))
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
return (
|
raise HTTPException(
|
||||||
{
|
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||||
"domain": domain,
|
detail={
|
||||||
"message": f"lnurl JSON response invalid: {exc}",
|
"domain": domain,
|
||||||
},
|
"message": f"lnurl JSON response invalid: {exc}",
|
||||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
})
|
||||||
)
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
@ -421,8 +417,9 @@ async def api_lnurlscan(code: str):
|
|||||||
async def api_perform_lnurlauth(callback: str):
|
async def api_perform_lnurlauth(callback: str):
|
||||||
err = await perform_lnurlauth(callback)
|
err = await perform_lnurlauth(callback)
|
||||||
if err:
|
if err:
|
||||||
return {"reason": err.reason}, HTTPStatus.SERVICE_UNAVAILABLE
|
raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason)
|
||||||
return "", HTTPStatus.OK
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@core_app.get("/api/v1/currencies")
|
@core_app.get("/api/v1/currencies")
|
||||||
|
@ -115,9 +115,9 @@ def api_validate_post_request(*, schema: dict):
|
|||||||
@wraps(view)
|
@wraps(view)
|
||||||
async def wrapped_view(**kwargs):
|
async def wrapped_view(**kwargs):
|
||||||
if "application/json" not in request.headers["Content-Type"]:
|
if "application/json" not in request.headers["Content-Type"]:
|
||||||
return (
|
raise HTTPException(
|
||||||
jsonify({"message": "Content-Type must be `application/json`."}),
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
HTTPStatus.BAD_REQUEST,
|
detail=jsonify({"message": "Content-Type must be `application/json`."})
|
||||||
)
|
)
|
||||||
|
|
||||||
v = Validator(schema)
|
v = Validator(schema)
|
||||||
@ -125,11 +125,12 @@ def api_validate_post_request(*, schema: dict):
|
|||||||
g().data = {key: data[key] for key in schema.keys() if key in data}
|
g().data = {key: data[key] for key in schema.keys() if key in data}
|
||||||
|
|
||||||
if not v.validate(g().data):
|
if not v.validate(g().data):
|
||||||
return (
|
raise HTTPException(
|
||||||
jsonify({"message": f"Errors in request data: {v.errors}"}),
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
HTTPStatus.BAD_REQUEST,
|
detail=jsonify({"message": f"Errors in request data: {v.errors}"})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return await view(**kwargs)
|
return await view(**kwargs)
|
||||||
|
|
||||||
return wrapped_view
|
return wrapped_view
|
||||||
@ -141,12 +142,19 @@ def check_user_exists(param: str = "usr"):
|
|||||||
def wrap(view):
|
def wrap(view):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
async def wrapped_view(**kwargs):
|
async def wrapped_view(**kwargs):
|
||||||
g().user = await get_user(request.args.get(param, type=str)) or abort(
|
g().user = await get_user(request.args.get(param, type=str))
|
||||||
HTTPStatus.NOT_FOUND, "User does not exist."
|
if not g().user:
|
||||||
)
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
|
detail="User does not exist."
|
||||||
|
)
|
||||||
|
|
||||||
if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS:
|
if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS:
|
||||||
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.UNAUTHORIZED,
|
||||||
|
detail="User not authorized."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
return await view(**kwargs)
|
return await view(**kwargs)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user