From fa0817731730353fc1868ede1e15d9dbdb28be80 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Fri, 10 Sep 2021 21:40:14 +0200 Subject: [PATCH 1/7] fix: FastAPify how data or exceptions are returned FastAPI handles returning HTTPStatus codes differently than Quart did --- lnbits/core/views/api.py | 179 +++++++++++++++++++-------------------- lnbits/decorators.py | 28 +++--- 2 files changed, 106 insertions(+), 101 deletions(-) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 12c3e0a79..9aea5d80e 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -31,23 +31,18 @@ from ..tasks import api_invoice_listeners @core_app.get("/api/v1/wallet") async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): - return ( - {"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat}, - HTTPStatus.OK, - ) - + return {"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat}, + @core_app.put("/api/v1/wallet/{new_name}") async def api_update_wallet(new_name: str, wallet: WalletTypeInfo = Depends(get_key_type)): await update_wallet(wallet.wallet.id, new_name) - return ( - { - "id": wallet.wallet.id, - "name": wallet.wallet.name, - "balance": wallet.wallet.balance_msat, - }, - HTTPStatus.OK, - ) + return { + "id": wallet.wallet.id, + "name": wallet.wallet.name, + "balance": wallet.wallet.balance_msat, + } + @core_app.get("/api/v1/payments") @@ -92,7 +87,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet): conn=conn, ) except InvoiceFailure as e: - return {"message": str(e)}, 520 + raise HTTPException(status_code=520, detail=str(e)) except Exception as exc: raise exc @@ -128,16 +123,15 @@ async def api_payments_create_invoice(data: CreateInvoiceData, wallet: Wallet): except (httpx.ConnectError, httpx.RequestError): lnurl_response = False - return ( - { - "payment_hash": invoice.payment_hash, - "payment_request": payment_request, - # maintain backwards compatibility with API clients: - "checking_id": invoice.payment_hash, - "lnurl_response": lnurl_response, - }, - HTTPStatus.CREATED, - ) + return { + "payment_hash": invoice.payment_hash, + "payment_request": payment_request, + # maintain backwards compatibility with API clients: + "checking_id": invoice.payment_hash, + "lnurl_response": lnurl_response, + } + + async def api_payments_pay_invoice(bolt11: str, wallet: Wallet): @@ -147,26 +141,34 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet): payment_request=bolt11, ) except ValueError as e: - return {"message": str(e)}, HTTPStatus.BAD_REQUEST + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=str(e) + ) except PermissionError as e: - return {"message": str(e)}, HTTPStatus.FORBIDDEN + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail=str(e) + ) except PaymentFailure as e: - return {"message": str(e)}, 520 + raise HTTPException( + status_code=520, + detail=str(e) + ) except Exception as exc: raise exc - return ( - { - "payment_hash": payment_hash, - # maintain backwards compatibility with API clients: - "checking_id": payment_hash, - }, - HTTPStatus.CREATED, - ) + return { + "payment_hash": payment_hash, + # maintain backwards compatibility with API clients: + "checking_id": payment_hash, + } + @core_app.post("/api/v1/payments", deprecated=True, - description="DEPRECATED. Use /api/v2/TBD and /api/v2/TBD instead") + description="DEPRECATED. Use /api/v2/TBD and /api/v2/TBD instead", + status_code=HTTPStatus.CREATED) async def api_payments_create(wallet: WalletTypeInfo = Depends(get_key_type), out: bool = True, invoiceData: Optional[CreateInvoiceData] = Body(None), bolt11: Optional[str] = Query(None)): @@ -201,32 +203,32 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): if r.is_error: raise httpx.ConnectError except (httpx.ConnectError, httpx.RequestError): - return ( - {"message": f"Failed to connect to {domain}."}, - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Failed to connect to {domain}." ) params = json.loads(r.text) if params.get("status") == "ERROR": - return ({"message": f"{domain} said: '{params.get('reason', '')}'"}, - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"{domain} said: '{params.get('reason', '')}'" ) + invoice = bolt11.decode(params["pr"]) if invoice.amount_msat != data.amount: - return ( - { - "message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}." - }, - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}." ) + if invoice.description_hash != g().data["description_hash"]: - return ( - { - "message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}." - }, - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}." ) + extra = {} @@ -242,15 +244,13 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): extra=extra, ) - return ( - { - "success_action": params.get("successAction"), - "payment_hash": payment_hash, - # maintain backwards compatibility with API clients: - "checking_id": payment_hash, - }, - HTTPStatus.CREATED, - ) + return { + "success_action": params.get("successAction"), + "payment_hash": payment_hash, + # maintain backwards compatibility with API clients: + "checking_id": payment_hash, + } + async def subscribe(request: Request, wallet: Wallet): this_wallet_id = wallet.wallet.id @@ -273,20 +273,21 @@ async def subscribe(request: Request, wallet: Wallet): try: while True: typ, data = await send_queue.get() - message = [f"event: {typ}".encode("utf-8")] if data: jdata = json.dumps(dict(data.dict(), pending=False)) - message.append(f"data: {jdata}".encode("utf-8")) - - yield dict(data=jdata.encode("utf-8"), event=typ.encode("utf-8")) + + # yield dict(id=1, event="this", data="1234") + # await asyncio.sleep(2) + yield dict(data=jdata, event=typ) + # yield dict(data=jdata.encode("utf-8"), event=typ.encode("utf-8")) except asyncio.CancelledError: return @core_app.get("/api/v1/payments/sse") async def api_payments_sse(request: Request, wallet: WalletTypeInfo = Depends(get_key_type)): - return EventSourceResponse(subscribe(request, wallet)) + return EventSourceResponse(subscribe(request, wallet), ping=20, media_type="text/event-stream") @core_app.get("/api/v1/payments/{payment_hash}") @@ -303,10 +304,8 @@ async def api_payment(payment_hash, wallet: WalletTypeInfo = Depends(get_key_typ except Exception: return {"paid": False}, HTTPStatus.OK - return ( - {"paid": not payment.pending, "preimage": payment.preimage}, - HTTPStatus.OK, - ) + return {"paid": not payment.pending, "preimage": payment.preimage} + @core_app.get("/api/v1/lnurlscan/{code}", dependencies=[Depends(WalletInvoiceKeyChecker())]) async def api_lnurlscan(code: str): @@ -326,7 +325,7 @@ async def api_lnurlscan(code: str): ) # will proceed with these values else: - return {"message": "invalid lnurl"}, HTTPStatus.BAD_REQUEST + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="invalid lnurl") # params is what will be returned to the client params: Dict = {"domain": domain} @@ -341,28 +340,25 @@ async def api_lnurlscan(code: str): async with httpx.AsyncClient() as client: r = await client.get(url, timeout=5) if r.is_error: - return ( - {"domain": domain, "message": "failed to get parameters"}, - HTTPStatus.SERVICE_UNAVAILABLE, + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + detail={"domain": domain, "message": "failed to get parameters"} ) try: data = json.loads(r.text) except json.decoder.JSONDecodeError: - return ( - { - "domain": domain, - "message": f"got invalid response '{r.text[:200]}'", - }, - HTTPStatus.SERVICE_UNAVAILABLE, + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + detail={"domain": domain, "message": f"got invalid response '{r.text[:200]}'"} ) try: tag = data["tag"] if tag == "channelRequest": - return ( - {"domain": domain, "kind": "channel", "message": "unsupported"}, - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail={"domain": domain, "kind": "channel", "message": "unsupported"} ) params.update(**data) @@ -407,13 +403,13 @@ async def api_lnurlscan(code: str): params.update(commentAllowed=data.get("commentAllowed", 0)) except KeyError as exc: - return ( - { - "domain": domain, - "message": f"lnurl JSON response invalid: {exc}", - }, - HTTPStatus.SERVICE_UNAVAILABLE, - ) + raise HTTPException( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + detail={ + "domain": domain, + "message": f"lnurl JSON response invalid: {exc}", + }) + return params @@ -421,8 +417,9 @@ async def api_lnurlscan(code: str): async def api_perform_lnurlauth(callback: str): err = await perform_lnurlauth(callback) if err: - return {"reason": err.reason}, HTTPStatus.SERVICE_UNAVAILABLE - return "", HTTPStatus.OK + raise HTTPException(status_code=HTTPStatus.SERVICE_UNAVAILABLE, detail=err.reason) + + return "" @core_app.get("/api/v1/currencies") diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 372d3955b..a1ced0ca1 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -115,9 +115,9 @@ def api_validate_post_request(*, schema: dict): @wraps(view) async def wrapped_view(**kwargs): if "application/json" not in request.headers["Content-Type"]: - return ( - jsonify({"message": "Content-Type must be `application/json`."}), - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=jsonify({"message": "Content-Type must be `application/json`."}) ) v = Validator(schema) @@ -125,10 +125,11 @@ def api_validate_post_request(*, schema: dict): g().data = {key: data[key] for key in schema.keys() if key in data} if not v.validate(g().data): - return ( - jsonify({"message": f"Errors in request data: {v.errors}"}), - HTTPStatus.BAD_REQUEST, + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=jsonify({"message": f"Errors in request data: {v.errors}"}) ) + return await view(**kwargs) @@ -141,12 +142,19 @@ def check_user_exists(param: str = "usr"): def wrap(view): @wraps(view) async def wrapped_view(**kwargs): - g().user = await get_user(request.args.get(param, type=str)) or abort( - HTTPStatus.NOT_FOUND, "User does not exist." - ) + g().user = await get_user(request.args.get(param, type=str)) + if not g().user: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="User does not exist." + ) if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS: - abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="User not authorized." + ) + return await view(**kwargs) From d8d8c6b454b63c503bbd5fb8b59db56ab202d736 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Fri, 10 Sep 2021 21:41:37 +0200 Subject: [PATCH 2/7] docs: add a FastAPI transition documentation --- docs/guide/fastapi_transition.md | 55 ++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 docs/guide/fastapi_transition.md diff --git a/docs/guide/fastapi_transition.md b/docs/guide/fastapi_transition.md new file mode 100644 index 000000000..efccf317e --- /dev/null +++ b/docs/guide/fastapi_transition.md @@ -0,0 +1,55 @@ +## Returning data from API calls +**old:** +```python +return ( + { + "id": wallet.wallet.id, + "name": wallet.wallet.name, + "balance": wallet.wallet.balance_msat + }, + HTTPStatus.OK, +) +``` +FastAPI returns `HTTPStatus.OK` by default id no Exception is raised + +**new:** +```python +return { + "id": wallet.wallet.id, + "name": wallet.wallet.name, + "balance": wallet.wallet.balance_msat +} +``` + +To change the default HTTPStatus, add it to the path decorator +```python +@core_app.post("/api/v1/payments", status_code=HTTPStatus.CREATED) +async def payments(): + pass +``` + +## Raise exceptions +**old:** +```python +return ( + {"message": f"Failed to connect to {domain}."}, + HTTPStatus.BAD_REQUEST, +) +``` + +**new:** + +Raise an exception to return a status code other than the default status code. +```python +raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Failed to connect to {domain}." +) +``` +## Possible optimizations +### Use Redis as a cache server +Instead of hitting the database over and over again, we can store a short lived object in [Redis](https://redis.io) for an arbitrary key. +Example: +* Get transactions for a wallet ID +* User data for a user id +* Wallet data for a Admin / Invoice key \ No newline at end of file From 63d02426859065722856bce99403a95fa16ca97f Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sat, 11 Sep 2021 11:02:48 +0200 Subject: [PATCH 3/7] fix: more return types --- lnbits/core/views/api.py | 6 +++--- lnbits/core/views/generic.py | 2 +- lnbits/core/views/public_api.py | 23 ++++++++++++++++------- lnbits/tasks.py | 4 +++- lnbits/wallets/lnpay.py | 6 ++++-- lnbits/wallets/opennode.py | 10 +++++++--- 6 files changed, 34 insertions(+), 17 deletions(-) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 9aea5d80e..d28d5eb2b 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -295,14 +295,14 @@ async def api_payment(payment_hash, wallet: WalletTypeInfo = Depends(get_key_typ payment = await wallet.wallet.get_payment(payment_hash) if not payment: - return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND + return {"message": "Payment does not exist."} elif not payment.pending: - return {"paid": True, "preimage": payment.preimage}, HTTPStatus.OK + return {"paid": True, "preimage": payment.preimage} try: await payment.check_pending() except Exception: - return {"paid": False}, HTTPStatus.OK + return {"paid": False} return {"paid": not payment.pending, "preimage": payment.preimage} diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 2e041802f..d7c950509 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -204,7 +204,7 @@ async def lnurlwallet(request: Request): async def manifest(usr: str): user = await get_user(usr) if not user: - return "", HTTPStatus.NOT_FOUND + raise HTTPException(status_code=HTTPStatus.NOT_FOUND) return { "short_name": "LNbits", diff --git a/lnbits/core/views/public_api.py b/lnbits/core/views/public_api.py index 027585219..70f949dc7 100644 --- a/lnbits/core/views/public_api.py +++ b/lnbits/core/views/public_api.py @@ -1,7 +1,7 @@ import asyncio import datetime from http import HTTPStatus - +from fastapi import HTTPException from lnbits import bolt11 from .. import core_app @@ -14,17 +14,23 @@ async def api_public_payment_longpolling(payment_hash): payment = await get_standalone_payment(payment_hash) if not payment: - return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Payment does not exist." + ) elif not payment.pending: - return {"status": "paid"}, HTTPStatus.OK + return {"status": "paid"} try: invoice = bolt11.decode(payment.bolt11) expiration = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry) if expiration < datetime.datetime.now(): - return {"status": "expired"}, HTTPStatus.OK + return {"status": "expired"} except: - return {"message": "Invalid bolt11 invoice."}, HTTPStatus.BAD_REQUEST + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Invalid bolt11 invoice." + ) payment_queue = asyncio.Queue(0) @@ -37,7 +43,7 @@ async def api_public_payment_longpolling(payment_hash): async for payment in payment_queue.get(): if payment.payment_hash == payment_hash: nonlocal response - response = ({"status": "paid"}, HTTPStatus.OK) + response = {"status": "paid"} cancel_scope.cancel() async def timeouter(cancel_scope): @@ -51,4 +57,7 @@ async def api_public_payment_longpolling(payment_hash): if response: return response else: - return {"message": "timeout"}, HTTPStatus.REQUEST_TIMEOUT + raise HTTPException( + status_code=HTTPStatus.REQUEST_TIMEOUT, + detail="timeout" + ) diff --git a/lnbits/tasks.py b/lnbits/tasks.py index ab1ebc46b..4e73a0af6 100644 --- a/lnbits/tasks.py +++ b/lnbits/tasks.py @@ -4,6 +4,8 @@ import traceback from http import HTTPStatus from typing import List, Callable +from fastapi.exceptions import HTTPException + from lnbits.settings import WALLET from lnbits.core.crud import ( get_payments, @@ -61,7 +63,7 @@ async def webhook_handler(): handler = getattr(WALLET, "webhook_listener", None) if handler: return await handler() - return "", HTTPStatus.NO_CONTENT + raise HTTPException(status_code=HTTPStatus.NO_CONTENT) internal_invoice_queue = asyncio.Queue(0) diff --git a/lnbits/wallets/lnpay.py b/lnbits/wallets/lnpay.py index 305400df4..ab8e0d817 100644 --- a/lnbits/wallets/lnpay.py +++ b/lnbits/wallets/lnpay.py @@ -1,5 +1,6 @@ import json import asyncio +from fastapi.exceptions import HTTPException import httpx from os import getenv from http import HTTPStatus @@ -133,7 +134,7 @@ class LNPayWallet(Wallet): or "event" not in data or data["event"].get("name") != "wallet_receive" ): - return "", HTTPStatus.NO_CONTENT + raise HTTPException(status_code=HTTPStatus.NO_CONTENT) lntx_id = data["data"]["wtx"]["lnTx"]["id"] async with httpx.AsyncClient() as client: @@ -145,4 +146,5 @@ class LNPayWallet(Wallet): if data["settled"]: await self.queue.put(lntx_id) - return "", HTTPStatus.NO_CONTENT + raise HTTPException(status_code=HTTPStatus.NO_CONTENT) + diff --git a/lnbits/wallets/opennode.py b/lnbits/wallets/opennode.py index d955cc0ba..ddc2849eb 100644 --- a/lnbits/wallets/opennode.py +++ b/lnbits/wallets/opennode.py @@ -1,4 +1,6 @@ import asyncio + +from fastapi.exceptions import HTTPException from lnbits.helpers import url_for import hmac import httpx @@ -133,14 +135,16 @@ class OpenNodeWallet(Wallet): async def webhook_listener(self): data = await request.form if "status" not in data or data["status"] != "paid": - return "", HTTPStatus.NO_CONTENT + raise HTTPException(status_code=HTTPStatus.NO_CONTENT) + charge_id = data["id"] x = hmac.new(self.auth["Authorization"].encode("ascii"), digestmod="sha256") x.update(charge_id.encode("ascii")) if x.hexdigest() != data["hashed_order"]: print("invalid webhook, not from opennode") - return "", HTTPStatus.NO_CONTENT + raise HTTPException(status_code=HTTPStatus.NO_CONTENT) await self.queue.put(charge_id) - return "", HTTPStatus.NO_CONTENT + raise HTTPException(status_code=HTTPStatus.NO_CONTENT) + From 9e7666826954544739fd7b9f4470df0217674231 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sat, 11 Sep 2021 11:47:05 +0200 Subject: [PATCH 4/7] fix: send payments via Wallet UI --- lnbits/core/views/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index d28d5eb2b..8d086bd32 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -31,7 +31,7 @@ from ..tasks import api_invoice_listeners @core_app.get("/api/v1/wallet") async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): - return {"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat}, + return {"id": wallet.wallet.id, "name": wallet.wallet.name, "balance": wallet.wallet.balance_msat} @core_app.put("/api/v1/wallet/{new_name}") @@ -171,7 +171,7 @@ async def api_payments_pay_invoice(bolt11: str, wallet: Wallet): status_code=HTTPStatus.CREATED) async def api_payments_create(wallet: WalletTypeInfo = Depends(get_key_type), out: bool = True, invoiceData: Optional[CreateInvoiceData] = Body(None), - bolt11: Optional[str] = Query(None)): + bolt11: Optional[str] = Body(None)): if wallet.wallet_type < 0 or wallet.wallet_type > 2: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="Key is invalid") From c2551bd76597726917bf638442b85fc22bbe1c73 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sat, 11 Sep 2021 12:28:29 +0200 Subject: [PATCH 5/7] docs: add another old way of raising exceptions --- docs/guide/fastapi_transition.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/guide/fastapi_transition.md b/docs/guide/fastapi_transition.md index efccf317e..d4688154c 100644 --- a/docs/guide/fastapi_transition.md +++ b/docs/guide/fastapi_transition.md @@ -35,6 +35,8 @@ return ( {"message": f"Failed to connect to {domain}."}, HTTPStatus.BAD_REQUEST, ) +# or the Quart way via abort function +abort(HTTPStatus.INTERNAL_SERVER_ERROR, "Could not process withdraw LNURL.") ``` **new:** From 7b69852accdd9c4bd7bbde950c7f9dc2af279c27 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sat, 11 Sep 2021 15:18:09 +0200 Subject: [PATCH 6/7] fix: make check_user_exists() work with FastAPI --- docs/guide/fastapi_transition.md | 51 ++++++++++++++++++++++++++++++++ lnbits/decorators.py | 40 ++++++++++--------------- 2 files changed, 66 insertions(+), 25 deletions(-) diff --git a/docs/guide/fastapi_transition.md b/docs/guide/fastapi_transition.md index d4688154c..6ae179e24 100644 --- a/docs/guide/fastapi_transition.md +++ b/docs/guide/fastapi_transition.md @@ -1,3 +1,21 @@ +## Check if a user exists and access user object +**old:** +```python +# decorators +@check_user_exists() +async def do_routing_stuff(): + pass +``` + +**new:** +If user doesn't exist, `Depends(check_user_exists)` will raise an exception. +If user exists, `user` will be the user object +```python +# depends calls +@core_html_routes.get("/my_route") +async def extensions(user: User = Depends(check_user_exists)): + pass +``` ## Returning data from API calls **old:** ```python @@ -48,6 +66,39 @@ raise HTTPException( detail=f"Failed to connect to {domain}." ) ``` + +## Extensions +**old:** +```python +from quart import Blueprint + +amilk_ext: Blueprint = Blueprint( + "amilk", __name__, static_folder="static", template_folder="templates" +) +``` + +**new:** +```python +from fastapi import APIRouter +from lnbits.jinja2_templating import Jinja2Templates +from lnbits.helpers import template_renderer +from fastapi.staticfiles import StaticFiles + +offlineshop_ext: APIRouter = APIRouter( + prefix="/Extension", + tags=["Offlineshop"] +) + +offlineshop_ext.mount( + "lnbits/extensions/offlineshop/static", + StaticFiles("lnbits/extensions/offlineshop/static") +) + +offlineshop_rndr = template_renderer([ + "lnbits/extensions/offlineshop/templates", +]) +``` + ## Possible optimizations ### Use Redis as a cache server Instead of hitting the database over and over again, we can store a short lived object in [Redis](https://redis.io) for an arbitrary key. diff --git a/lnbits/decorators.py b/lnbits/decorators.py index a1ced0ca1..ff42d0fd5 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -2,7 +2,8 @@ from functools import wraps from http import HTTPStatus from fastapi.security import api_key -from lnbits.core.models import Wallet +from pydantic.types import UUID4 +from lnbits.core.models import User, Wallet from typing import List, Union from uuid import UUID @@ -138,29 +139,18 @@ def api_validate_post_request(*, schema: dict): return wrap -def check_user_exists(param: str = "usr"): - def wrap(view): - @wraps(view) - async def wrapped_view(**kwargs): - g().user = await get_user(request.args.get(param, type=str)) - if not g().user: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="User does not exist." - ) - - if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS: - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, - detail="User not authorized." - ) - - - return await view(**kwargs) - - return wrapped_view - - return wrap - +async def check_user_exists(usr: UUID4) -> User: + g().user = await get_user(usr.hex) + if not g().user: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="User does not exist." + ) + if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="User not authorized." + ) + return g().user From 3bae5c92c2004ce512cac5bc0f3a78a89b2be8fa Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sat, 11 Sep 2021 20:44:22 +0200 Subject: [PATCH 7/7] fix: /extensions endpoint --- lnbits/core/views/generic.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index d7c950509..b0055af8d 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -15,6 +15,8 @@ from starlette.responses import HTMLResponse from lnbits.core import db from lnbits.helpers import template_renderer, url_for from lnbits.requestvars import g +from lnbits.core.models import User +from lnbits.decorators import check_user_exists from lnbits.settings import (LNBITS_ALLOWED_USERS, LNBITS_SITE_TITLE, SERVICE_FEE) @@ -35,10 +37,13 @@ async def home(request: Request, lightning: str = None): return template_renderer().TemplateResponse("core/index.html", {"request": request, "lnurl": lightning}) -@core_html_routes.get("/extensions") -# @validate_uuids(["usr"], required=True) -# @check_user_exists() -async def extensions(request: Request, enable: str, disable: str): +@core_html_routes.get("/extensions", name="core.extensions") +async def extensions( + request: Request, + user: User = Depends(check_user_exists), + enable: str= Query(None), + disable: str = Query(None) + ): extension_to_enable = enable extension_to_disable = disable @@ -47,13 +52,18 @@ async def extensions(request: Request, enable: str, disable: str): if extension_to_enable: await update_user_extension( - user_id=g.user.id, extension=extension_to_enable, active=True + user_id=user.id, extension=extension_to_enable, active=True ) elif extension_to_disable: await update_user_extension( - user_id=g.user.id, extension=extension_to_disable, active=False + user_id=user.id, extension=extension_to_disable, active=False ) - return template_renderer().TemplateResponse("core/extensions.html", {"request": request, "user": get_user(g.user.id)}) + + # Update user as his extensions have been updated + if extension_to_enable or extension_to_disable: + user = await get_user(user.id) + + return template_renderer().TemplateResponse("core/extensions.html", {"request": request, "user": user.dict()}) @core_html_routes.get("/wallet", response_class=HTMLResponse)