From 570e8d28dd73f05259ee1fbec54936d658e3a884 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Fri, 15 Oct 2021 19:55:24 +0200 Subject: [PATCH] fix: lndhub auth handling --- lnbits/decorators.py | 20 ++++++-------- lnbits/extensions/lndhub/decorators.py | 37 ++++++++++---------------- lnbits/extensions/lndhub/views_api.py | 14 +++++----- 3 files changed, 29 insertions(+), 42 deletions(-) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 056ed8044..5962c11e4 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -15,6 +15,7 @@ from fastapi.params import Security from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from fastapi.security import OAuth2PasswordBearer from fastapi.security.base import SecurityBase +from fastapi import status from starlette.requests import Request from lnbits.core.crud import get_user, get_wallet_for_key @@ -84,25 +85,20 @@ class WalletTypeInfo(): self.wallet_type = wallet_type self.wallet = wallet -api_key_header_xapi = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's") -api_key_header_auth = APIKeyHeader(name="AUTHORIZATION", auto_error=False, description="Admin or Invoice key for wallet API's") +api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's") api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's") async def get_key_type(r: Request, - api_key_header_auth: str = Security(api_key_header_auth), - api_key_header: str = Security(api_key_header_auth), + api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query)) -> WalletTypeInfo: # 0: admin # 1: invoice # 2: invalid - # print("TOKEN", b64decode(token).decode("utf-8").split(":")) - - if api_key_header_xapi: - token = api_key_header_xapi - elif api_key_header_auth: - _, token = b64decode(api_key_header_auth).decode("utf-8").split(":") - elif api_key_query: - token = api_key_query + + if not api_key_header and not api_key_query: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + token = api_key_header if api_key_header else api_key_query + try: checker = WalletAdminKeyChecker(api_key=token) await checker.__call__(r) diff --git a/lnbits/extensions/lndhub/decorators.py b/lnbits/extensions/lndhub/decorators.py index 7526135ba..74d5fa764 100644 --- a/lnbits/extensions/lndhub/decorators.py +++ b/lnbits/extensions/lndhub/decorators.py @@ -1,33 +1,24 @@ from base64 import b64decode from functools import wraps +from fastapi.param_functions import Security + +from fastapi.security.api_key import APIKeyHeader from lnbits.core.crud import get_wallet_for_key -from fastapi import Request -from http import HTTPStatus +from fastapi import Request, status from starlette.exceptions import HTTPException -from starlette.responses import HTMLResponse, JSONResponse # type: ignore +from starlette.responses import HTMLResponse, JSONResponse + +from lnbits.decorators import WalletTypeInfo, get_key_type # type: ignore -def check_wallet(requires_admin=False): - def wrap(view): - @wraps(view) - async def wrapped_view(request: Request, **kwargs): - token = request.headers["Authorization"].split("Bearer ")[1] - key_type, key = b64decode(token).decode("utf-8").split(":") +api_key_header_auth = APIKeyHeader(name="AUTHORIZATION", auto_error=False, description="Admin or Invoice key for LNDHub API's") +async def check_wallet(r: Request, api_key_header_auth: str = Security(api_key_header_auth)) -> WalletTypeInfo: + if not api_key_header_auth: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - if requires_admin and key_type != "admin": - raise HTTPException( - status_code=HTTPStatus.FORBIDDEN, - detail="insufficient permissions", - ) - g.wallet = await get_wallet_for_key(key, key_type) - if not g.wallet: - raise HTTPException( - status_code=HTTPStatus.FORBIDDEN, - detail="insufficient permissions", - ) - return await view(**kwargs) + t = api_key_header_auth.split(" ")[1] + _, token = b64decode(t).decode("utf-8").split(":") - return wrapped_view + return await get_key_type(r, api_key_header=token) - return wrap diff --git a/lnbits/extensions/lndhub/views_api.py b/lnbits/extensions/lndhub/views_api.py index 650c01323..edbb3cf3b 100644 --- a/lnbits/extensions/lndhub/views_api.py +++ b/lnbits/extensions/lndhub/views_api.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from lnbits.core.services import pay_invoice, create_invoice from lnbits.core.crud import get_payments, delete_expired_invoices -from lnbits.decorators import api_validate_post_request, WalletTypeInfo, get_key_type +from lnbits.decorators import WalletTypeInfo from lnbits.settings import WALLET from lnbits import bolt11 @@ -52,7 +52,7 @@ class AddInvoice(BaseModel): @lndhub_ext.post("/ext/addinvoice") async def lndhub_addinvoice( data: AddInvoice, - wallet: WalletTypeInfo = Depends(get_key_type) + wallet: WalletTypeInfo = Depends(check_wallet) ): try: _, pr = await create_invoice( @@ -79,7 +79,7 @@ async def lndhub_addinvoice( @lndhub_ext.post("/ext/payinvoice") async def lndhub_payinvoice( - wallet: WalletTypeInfo = Depends(get_key_type), invoice: str = Query(None) + wallet: WalletTypeInfo = Depends(check_wallet), invoice: str = Query(None) ): try: await pay_invoice( @@ -112,7 +112,7 @@ async def lndhub_payinvoice( @lndhub_ext.get("/ext/balance") # @check_wallet() async def lndhub_balance( - wallet: WalletTypeInfo = Depends(get_key_type), + wallet: WalletTypeInfo = Depends(check_wallet), ): return {"BTC": {"AvailableBalance": wallet.wallet.balance}} @@ -120,7 +120,7 @@ async def lndhub_balance( @lndhub_ext.get("/ext/gettxs") # @check_wallet() async def lndhub_gettxs( - wallet: WalletTypeInfo = Depends(get_key_type), limit: int = Query(0, ge=0, lt=200) + wallet: WalletTypeInfo = Depends(check_wallet), limit: int = Query(0, ge=0, lt=200) ): print("WALLET", wallet) for payment in await get_payments( @@ -161,7 +161,7 @@ async def lndhub_gettxs( @lndhub_ext.get("/ext/getuserinvoices") async def lndhub_getuserinvoices( - wallet: WalletTypeInfo = Depends(get_key_type), limit: int = Query(0, ge=0, lt=200) + wallet: WalletTypeInfo = Depends(check_wallet), limit: int = Query(0, ge=0, lt=200) ): await delete_expired_invoices() for invoice in await get_payments( @@ -203,7 +203,7 @@ async def lndhub_getuserinvoices( @lndhub_ext.get("/ext/getbtc") -async def lndhub_getbtc(wallet: WalletTypeInfo = Depends(get_key_type)): +async def lndhub_getbtc(wallet: WalletTypeInfo = Depends(check_wallet)): "load an address for incoming onchain btc" return []