fix: FastAPify how data or exceptions are returned

FastAPI handles returning HTTPStatus codes differently than Quart did
This commit is contained in:
Stefan Stammberger 2021-09-10 21:40:14 +02:00
parent d9849d43d2
commit fa08177317
No known key found for this signature in database
GPG Key ID: 645FA807E935D9D5
2 changed files with 106 additions and 101 deletions

View File

@ -31,23 +31,18 @@ from ..tasks import api_invoice_listeners
@core_app.get("/api/v1/wallet") @core_app.get("/api/v1/wallet")
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
return ( return {"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat},
{"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat},
HTTPStatus.OK,
)
@core_app.put("/api/v1/wallet/{new_name}") @core_app.put("/api/v1/wallet/{new_name}")
async def api_update_wallet(new_name: str, wallet: WalletTypeInfo = Depends(get_key_type)): async def api_update_wallet(new_name: str, wallet: WalletTypeInfo = Depends(get_key_type)):
await update_wallet(wallet.wallet.id, new_name) await update_wallet(wallet.wallet.id, new_name)
return ( return {
{ "id": wallet.wallet.id,
"id": wallet.wallet.id, "name": wallet.wallet.name,
"name": wallet.wallet.name, "balance": wallet.wallet.balance_msat,
"balance": wallet.wallet.balance_msat, }
},
HTTPStatus.OK,
)
@core_app.get("/api/v1/payments") @core_app.get("/api/v1/payments")
@ -92,7 +87,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
conn=conn, conn=conn,
) )
except InvoiceFailure as e: except InvoiceFailure as e:
return {"message": str(e)}, 520 raise HTTPException(status_code=520, detail=str(e))
except Exception as exc: except Exception as exc:
raise exc raise exc
@ -128,16 +123,15 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet):
except (httpx.ConnectError, httpx.RequestError): except (httpx.ConnectError, httpx.RequestError):
lnurl_response = False lnurl_response = False
return ( return {
{ "payment_hash": invoice.payment_hash,
"payment_hash": invoice.payment_hash, "payment_request": payment_request,
"payment_request": payment_request, # maintain backwards compatibility with API clients:
# maintain backwards compatibility with API clients: "checking_id": invoice.payment_hash,
"checking_id": invoice.payment_hash, "lnurl_response": lnurl_response,
"lnurl_response": lnurl_response, }
},
HTTPStatus.CREATED,
)
async def api_payments_pay_invoice(bolt11: str, wallet: Wallet): async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
@ -147,26 +141,34 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet):
payment_request=bolt11, payment_request=bolt11,
) )
except ValueError as e: except ValueError as e:
return {"message": str(e)}, HTTPStatus.BAD_REQUEST raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=str(e)
)
except PermissionError as e: except PermissionError as e:
return {"message": str(e)}, HTTPStatus.FORBIDDEN raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail=str(e)
)
except PaymentFailure as e: except PaymentFailure as e:
return {"message": str(e)}, 520 raise HTTPException(
status_code=520,
detail=str(e)
)
except Exception as exc: except Exception as exc:
raise exc raise exc
return ( return {
{ "payment_hash": payment_hash,
"payment_hash": payment_hash, # maintain backwards compatibility with API clients:
# maintain backwards compatibility with API clients: "checking_id": payment_hash,
"checking_id": payment_hash, }
},
HTTPStatus.CREATED,
)
@core_app.post("/api/v1/payments", deprecated=True, @core_app.post("/api/v1/payments", deprecated=True,
description="DEPRECATED. Use /api/v2/TBD and /api/v2/TBD instead") description="DEPRECATED. Use /api/v2/TBD and /api/v2/TBD instead",
status_code=HTTPStatus.CREATED)
async def api_payments_create(wallet: WalletTypeInfo = Depends(get_key_type), out: bool = True, async def api_payments_create(wallet: WalletTypeInfo = Depends(get_key_type), out: bool = True,
invoiceData: Optional[CreateInvoiceData] = Body(None), invoiceData: Optional[CreateInvoiceData] = Body(None),
bolt11: Optional[str] = Query(None)): bolt11: Optional[str] = Query(None)):
@ -201,33 +203,33 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
if r.is_error: if r.is_error:
raise httpx.ConnectError raise httpx.ConnectError
except (httpx.ConnectError, httpx.RequestError): except (httpx.ConnectError, httpx.RequestError):
return ( raise HTTPException(
{"message": f"Failed to connect to {domain}."}, status_code=HTTPStatus.BAD_REQUEST,
HTTPStatus.BAD_REQUEST, detail=f"Failed to connect to {domain}."
) )
params = json.loads(r.text) params = json.loads(r.text)
if params.get("status") == "ERROR": if params.get("status") == "ERROR":
return ({"message": f"{domain} said: '{params.get('reason', '')}'"}, raise HTTPException(
HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=f"{domain} said: '{params.get('reason', '')}'"
) )
invoice = bolt11.decode(params["pr"]) invoice = bolt11.decode(params["pr"])
if invoice.amount_msat != data.amount: if invoice.amount_msat != data.amount:
return ( raise HTTPException(
{ status_code=HTTPStatus.BAD_REQUEST,
"message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}." detail=f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}."
},
HTTPStatus.BAD_REQUEST,
) )
if invoice.description_hash != g().data["description_hash"]: if invoice.description_hash != g().data["description_hash"]:
return ( raise HTTPException(
{ status_code=HTTPStatus.BAD_REQUEST,
"message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}." detail=f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}."
},
HTTPStatus.BAD_REQUEST,
) )
extra = {} extra = {}
if params.get("successAction"): if params.get("successAction"):
@ -242,15 +244,13 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
extra=extra, extra=extra,
) )
return ( return {
{ "success_action": params.get("successAction"),
"success_action": params.get("successAction"), "payment_hash": payment_hash,
"payment_hash": payment_hash, # maintain backwards compatibility with API clients:
# maintain backwards compatibility with API clients: "checking_id": payment_hash,
"checking_id": payment_hash, }
},
HTTPStatus.CREATED,
)
async def subscribe(request: Request, wallet: Wallet): async def subscribe(request: Request, wallet: Wallet):
this_wallet_id = wallet.wallet.id this_wallet_id = wallet.wallet.id
@ -273,20 +273,21 @@ async def subscribe(request: Request, wallet: Wallet):
try: try:
while True: while True:
typ, data = await send_queue.get() typ, data = await send_queue.get()
message = [f"event: {typ}".encode("utf-8")]
if data: if data:
jdata = json.dumps(dict(data.dict(), pending=False)) jdata = json.dumps(dict(data.dict(), pending=False))
message.append(f"data: {jdata}".encode("utf-8"))
yield dict(data=jdata.encode("utf-8"), event=typ.encode("utf-8")) # yield dict(id=1, event="this", data="1234")
# await asyncio.sleep(2)
yield dict(data=jdata, event=typ)
# yield dict(data=jdata.encode("utf-8"), event=typ.encode("utf-8"))
except asyncio.CancelledError: except asyncio.CancelledError:
return return
@core_app.get("/api/v1/payments/sse") @core_app.get("/api/v1/payments/sse")
async def api_payments_sse(request: Request, wallet: WalletTypeInfo = Depends(get_key_type)): async def api_payments_sse(request: Request, wallet: WalletTypeInfo = Depends(get_key_type)):
return EventSourceResponse(subscribe(request, wallet)) return EventSourceResponse(subscribe(request, wallet), ping=20, media_type="text/event-stream")
@core_app.get("/api/v1/payments/{payment_hash}") @core_app.get("/api/v1/payments/{payment_hash}")
@ -303,10 +304,8 @@ async def api_payment(payment_hash, wallet: WalletTypeInfo = Depends(get_key_typ
except Exception: except Exception:
return {"paid": False}, HTTPStatus.OK return {"paid": False}, HTTPStatus.OK
return ( return {"paid": not payment.pending, "preimage": payment.preimage}
{"paid": not payment.pending, "preimage": payment.preimage},
HTTPStatus.OK,
)
@core_app.get("/api/v1/lnurlscan/{code}", dependencies=[Depends(WalletInvoiceKeyChecker())]) @core_app.get("/api/v1/lnurlscan/{code}", dependencies=[Depends(WalletInvoiceKeyChecker())])
async def api_lnurlscan(code: str): async def api_lnurlscan(code: str):
@ -326,7 +325,7 @@ async def api_lnurlscan(code: str):
) )
# will proceed with these values # will proceed with these values
else: else:
return {"message": "invalid lnurl"}, HTTPStatus.BAD_REQUEST raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="invalid lnurl")
# params is what will be returned to the client # params is what will be returned to the client
params: Dict = {"domain": domain} params: Dict = {"domain": domain}
@ -341,28 +340,25 @@ async def api_lnurlscan(code: str):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.get(url, timeout=5) r = await client.get(url, timeout=5)
if r.is_error: if r.is_error:
return ( raise HTTPException(
{"domain": domain, "message": "failed to get parameters"}, status_code=HTTPStatus.SERVICE_UNAVAILABLE,
HTTPStatus.SERVICE_UNAVAILABLE, detail={"domain": domain, "message": "failed to get parameters"}
) )
try: try:
data = json.loads(r.text) data = json.loads(r.text)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
return ( raise HTTPException(
{ status_code=HTTPStatus.SERVICE_UNAVAILABLE,
"domain": domain, detail={"domain": domain, "message": f"got invalid response '{r.text[:200]}'"}
"message": f"got invalid response '{r.text[:200]}'",
},
HTTPStatus.SERVICE_UNAVAILABLE,
) )
try: try:
tag = data["tag"] tag = data["tag"]
if tag == "channelRequest": if tag == "channelRequest":
return ( raise HTTPException(
{"domain": domain, "kind": "channel", "message": "unsupported"}, status_code=HTTPStatus.BAD_REQUEST,
HTTPStatus.BAD_REQUEST, detail={"domain": domain, "kind": "channel", "message": "unsupported"}
) )
params.update(**data) params.update(**data)
@ -407,13 +403,13 @@ async def api_lnurlscan(code: str):
params.update(commentAllowed=data.get("commentAllowed", 0)) params.update(commentAllowed=data.get("commentAllowed", 0))
except KeyError as exc: except KeyError as exc:
return ( raise HTTPException(
{ status_code=HTTPStatus.SERVICE_UNAVAILABLE,
"domain": domain, detail={
"message": f"lnurl JSON response invalid: {exc}", "domain": domain,
}, "message": f"lnurl JSON response invalid: {exc}",
HTTPStatus.SERVICE_UNAVAILABLE, })
)
return params return params
@ -421,8 +417,9 @@ async def api_lnurlscan(code: str):
async def api_perform_lnurlauth(callback: str): async def api_perform_lnurlauth(callback: str):
err = await perform_lnurlauth(callback) err = await perform_lnurlauth(callback)
if err: if err:
return {"reason": err.reason}, HTTPStatus.SERVICE_UNAVAILABLE raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason)
return "", HTTPStatus.OK
return ""
@core_app.get("/api/v1/currencies") @core_app.get("/api/v1/currencies")

View File

@ -115,9 +115,9 @@ def api_validate_post_request(*, schema: dict):
@wraps(view) @wraps(view)
async def wrapped_view(**kwargs): async def wrapped_view(**kwargs):
if "application/json" not in request.headers["Content-Type"]: if "application/json" not in request.headers["Content-Type"]:
return ( raise HTTPException(
jsonify({"message": "Content-Type must be `application/json`."}), status_code=HTTPStatus.BAD_REQUEST,
HTTPStatus.BAD_REQUEST, detail=jsonify({"message": "Content-Type must be `application/json`."})
) )
v = Validator(schema) v = Validator(schema)
@ -125,11 +125,12 @@ def api_validate_post_request(*, schema: dict):
g().data = {key: data[key] for key in schema.keys() if key in data} g().data = {key: data[key] for key in schema.keys() if key in data}
if not v.validate(g().data): if not v.validate(g().data):
return ( raise HTTPException(
jsonify({"message": f"Errors in request data: {v.errors}"}), status_code=HTTPStatus.BAD_REQUEST,
HTTPStatus.BAD_REQUEST, detail=jsonify({"message": f"Errors in request data: {v.errors}"})
) )
return await view(**kwargs) return await view(**kwargs)
return wrapped_view return wrapped_view
@ -141,12 +142,19 @@ def check_user_exists(param: str = "usr"):
def wrap(view): def wrap(view):
@wraps(view) @wraps(view)
async def wrapped_view(**kwargs): async def wrapped_view(**kwargs):
g().user = await get_user(request.args.get(param, type=str)) or abort( g().user = await get_user(request.args.get(param, type=str))
HTTPStatus.NOT_FOUND, "User does not exist." if not g().user:
) raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="User does not exist."
)
if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS: if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="User not authorized."
)
return await view(**kwargs) return await view(**kwargs)