From 6730c6ed67371ea21fdfff4e7f12e1eea7cab34d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Mon, 13 May 2024 16:26:25 +0200 Subject: [PATCH] refactor: fix duplicate keychecker (#2339) * refactor: fix duplicate keychecker - refactor KeyChecker to be more approachable - only 1 sql query needed even if you use `get_key_type` - rename `WalletType` to `KeyType` wallet type was misleading fix test sorting * fixup! * revert 404 --- lnbits/core/crud.py | 5 - lnbits/core/models.py | 4 +- lnbits/core/views/api.py | 8 +- lnbits/core/views/payment_api.py | 4 +- lnbits/core/views/wallet_api.py | 4 +- lnbits/decorators.py | 239 ++++++++++--------------------- tests/api/test_api.py | 14 ++ 7 files changed, 101 insertions(+), 177 deletions(-) diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 18018fd45..dedb9911a 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -8,7 +8,6 @@ import shortuuid from passlib.context import CryptContext from lnbits.core.db import db -from lnbits.core.models import WalletType from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page from lnbits.extension_manager import InstallableExtension from lnbits.settings import ( @@ -628,7 +627,6 @@ async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[W async def get_wallet_for_key( key: str, - key_type: WalletType = WalletType.invoice, conn: Optional[Connection] = None, ) -> Optional[Wallet]: row = await (conn or db).fetchone( @@ -643,9 +641,6 @@ async def get_wallet_for_key( if not row: return None - if key_type == WalletType.admin and row["adminkey"] != key: - return None - return Wallet(**row) diff --git a/lnbits/core/models.py b/lnbits/core/models.py index ed8a800a4..6a2f272c3 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -68,7 +68,7 @@ class Wallet(BaseWallet): return await get_standalone_payment(payment_hash) -class WalletType(Enum): +class KeyType(Enum): admin = 0 invoice = 1 invalid = 2 @@ -80,7 +80,7 @@ class WalletType(Enum): @dataclass class WalletTypeInfo: - wallet_type: WalletType + key_type: KeyType wallet: Wallet diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index b78995b19..2856151ae 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -25,8 +25,8 @@ from lnbits.core.models import ( from lnbits.decorators import ( WalletTypeInfo, check_user_exists, - get_key_type, require_admin_key, + require_invoice_key, ) from lnbits.lnurl import decode as lnurl_decode from lnbits.settings import settings @@ -67,7 +67,7 @@ async def api_wallets(user: User = Depends(check_user_exists)) -> List[BaseWalle async def api_create_account(data: CreateWallet) -> Wallet: if not settings.new_accounts_allowed: raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.FORBIDDEN, detail="Account creation is disabled.", ) account = await create_account() @@ -75,7 +75,9 @@ async def api_create_account(data: CreateWallet) -> Wallet: @api_router.get("/api/v1/lnurlscan/{code}") -async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type)): +async def api_lnurlscan( + code: str, wallet: WalletTypeInfo = Depends(require_invoice_key) +): try: url = str(lnurl_decode(code)) domain = urlparse(url).netloc diff --git a/lnbits/core/views/payment_api.py b/lnbits/core/views/payment_api.py index 75c96f045..ec5e027b8 100644 --- a/lnbits/core/views/payment_api.py +++ b/lnbits/core/views/payment_api.py @@ -26,11 +26,11 @@ from lnbits.core.models import ( CreateInvoice, CreateLnurl, DecodePayment, + KeyType, Payment, PaymentFilters, PaymentHistoryPoint, Wallet, - WalletType, ) from lnbits.db import Filters, Page from lnbits.decorators import ( @@ -252,7 +252,7 @@ async def api_payments_create( wallet: WalletTypeInfo = Depends(require_invoice_key), invoice_data: CreateInvoice = Body(...), ): - if invoice_data.out is True and wallet.wallet_type == WalletType.admin: + if invoice_data.out is True and wallet.key_type == KeyType.admin: if not invoice_data.bolt11: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, diff --git a/lnbits/core/views/wallet_api.py b/lnbits/core/views/wallet_api.py index 94a8dadfb..543dd1bd6 100644 --- a/lnbits/core/views/wallet_api.py +++ b/lnbits/core/views/wallet_api.py @@ -8,8 +8,8 @@ from fastapi import ( from lnbits.core.models import ( CreateWallet, + KeyType, Wallet, - WalletType, ) from lnbits.decorators import ( WalletTypeInfo, @@ -28,7 +28,7 @@ wallet_router = APIRouter(prefix="/api/v1/wallet", tags=["Wallet"]) @wallet_router.get("") async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): - if wallet.wallet_type == WalletType.admin: + if wallet.key_type == KeyType.admin: return { "id": wallet.wallet.id, "name": wallet.wallet.name, diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 862ff2656..497c8ffed 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -17,23 +17,32 @@ from lnbits.core.crud import ( get_user, get_wallet_for_key, ) -from lnbits.core.models import User, Wallet, WalletType, WalletTypeInfo +from lnbits.core.models import KeyType, User, WalletTypeInfo from lnbits.db import Filter, Filters, TFilterModel from lnbits.settings import AuthMethods, settings oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False) +api_key_header = APIKeyHeader( + name="X-API-KEY", + auto_error=False, + description="Admin or Invoice key for wallet API's", +) +api_key_query = APIKeyQuery( + name="api-key", + auto_error=False, + description="Admin or Invoice key for wallet API's", +) + class KeyChecker(SecurityBase): def __init__( self, - scheme_name: Optional[str] = None, - auto_error: bool = True, api_key: Optional[str] = None, + expected_key_type: Optional[KeyType] = None, ): - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error: bool = auto_error - self._key_type: WalletType = WalletType.invoice + self.auto_error: bool = True + self.expected_key_type = expected_key_type self._api_key = api_key if api_key: openapi_model = APIKey( @@ -49,185 +58,82 @@ class KeyChecker(SecurityBase): name="X-API-KEY", description="Wallet API Key - HEADER", ) - self.wallet: Optional[Wallet] = None self.model: APIKey = openapi_model - async def __call__(self, request: Request): - try: - key_value = ( - self._api_key - if self._api_key - else request.headers.get("X-API-KEY") or request.query_params["api-key"] - ) - # FIXME: Find another way to validate the key. A fetch from DB should be - # avoided here. Also, we should not return the wallet here - thats - # silly. Possibly store it in a Redis DB - wallet = await get_wallet_for_key(key_value, self._key_type) - if not wallet: - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, - detail="Invalid key or wallet.", - ) - self.wallet = wallet - except KeyError as exc: + async def __call__(self, request: Request) -> WalletTypeInfo: + + key_value = ( + self._api_key + if self._api_key + else request.headers.get("X-API-KEY") or request.query_params.get("api-key") + ) + + if not key_value: raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail="`X-API-KEY` header missing." - ) from exc + status_code=HTTPStatus.UNAUTHORIZED, + detail="No Api Key provided.", + ) + wallet = await get_wallet_for_key(key_value) -class WalletInvoiceKeyChecker(KeyChecker): - """ - WalletInvoiceKeyChecker will ensure that the provided invoice - wallet key is correct and populate g().wallet with the wallet - for the key in `X-API-key`. + if not wallet: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Wallet not found.", + ) - The checker will raise an HTTPException when the key is wrong in some ways. - """ + if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Invalid adminkey.", + ) - def __init__( - self, - scheme_name: Optional[str] = None, - auto_error: bool = True, - api_key: Optional[str] = None, - ): - super().__init__(scheme_name, auto_error, api_key) - self._key_type = WalletType.invoice + if ( + wallet.user != settings.super_user + and wallet.user not in settings.lnbits_admin_users + and settings.lnbits_admin_extensions + and request["path"].split("/")[1] in settings.lnbits_admin_extensions + ): + raise HTTPException( + status_code=HTTPStatus.FORBIDDEN, + detail="User not authorized for this extension.", + ) - -class WalletAdminKeyChecker(KeyChecker): - """ - WalletAdminKeyChecker will ensure that the provided admin - wallet key is correct and populate g().wallet with the wallet - for the key in `X-API-key`. - - The checker will raise an HTTPException when the key is wrong in some ways. - """ - - def __init__( - self, - scheme_name: Optional[str] = None, - auto_error: bool = True, - api_key: Optional[str] = None, - ): - super().__init__(scheme_name, auto_error, api_key) - self._key_type = WalletType.admin - - -api_key_header = APIKeyHeader( - name="X-API-KEY", - auto_error=False, - description="Admin or Invoice key for wallet API's", -) -api_key_query = APIKeyQuery( - name="api-key", - auto_error=False, - description="Admin or Invoice key for wallet API's", -) + key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice + return WalletTypeInfo(key_type, wallet) async def get_key_type( - r: Request, + request: Request, api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query), ) -> WalletTypeInfo: - token = api_key_header or api_key_query - - if not token: - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, - detail="Invoice (or Admin) key required.", - ) - - for wallet_type, wallet_checker in zip( - [WalletType.admin, WalletType.invoice], - [WalletAdminKeyChecker, WalletInvoiceKeyChecker], - ): - try: - checker = wallet_checker(api_key=token) - await checker.__call__(r) - if checker.wallet is None: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist." - ) - wallet = WalletTypeInfo(wallet_type, checker.wallet) - if ( - wallet.wallet.user != settings.super_user - and wallet.wallet.user not in settings.lnbits_admin_users - ) and ( - settings.lnbits_admin_extensions - and r["path"].split("/")[1] in settings.lnbits_admin_extensions - ): - raise HTTPException( - status_code=HTTPStatus.FORBIDDEN, - detail="User not authorized for this extension.", - ) - return wallet - except HTTPException as exc: - if exc.status_code == HTTPStatus.BAD_REQUEST: - raise - elif exc.status_code == HTTPStatus.UNAUTHORIZED: - # we pass this in case it is not an invoice key, nor an admin key, - # and then return NOT_FOUND at the end of this block - pass - else: - raise - except Exception: - raise - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist." - ) + check: KeyChecker = KeyChecker(api_key=api_key_header or api_key_query) + return await check(request) async def require_admin_key( - r: Request, + request: Request, api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query), -): - token = api_key_header or api_key_query - - if not token: - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, - detail="Admin key required.", - ) - - wallet = await get_key_type(r, token) - - if wallet.wallet_type != 0: - # If wallet type is not admin then return the unauthorized status - # This also covers when the user passes an invalid key type - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required." - ) - else: - return wallet +) -> WalletTypeInfo: + check: KeyChecker = KeyChecker( + api_key=api_key_header or api_key_query, + expected_key_type=KeyType.admin, + ) + return await check(request) async def require_invoice_key( - r: Request, + request: Request, api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query), -): - token = api_key_header or api_key_query - - if not token: - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, - detail="Invoice (or Admin) key required.", - ) - - wallet = await get_key_type(r, token) - - if ( - wallet.wallet_type != WalletType.admin - and wallet.wallet_type != WalletType.invoice - ): - raise HTTPException( - status_code=HTTPStatus.UNAUTHORIZED, - detail="Invoice (or Admin) key required.", - ) - else: - return wallet +) -> WalletTypeInfo: + check: KeyChecker = KeyChecker( + api_key=api_key_header or api_key_query, + expected_key_type=KeyType.invoice, + ) + return await check(request) async def check_access_token( @@ -255,8 +161,15 @@ async def check_user_exists( user = await get_user(account.id) assert user, "User not found for account." - if not user.admin and r["path"].split("/")[1] in settings.lnbits_admin_extensions: - raise HTTPException(HTTPStatus.FORBIDDEN, "User not authorized for extension.") + if ( + user.id != settings.super_user + and user.id not in settings.lnbits_admin_users + and settings.lnbits_admin_extensions + and r["path"].split("/")[1] in settings.lnbits_admin_extensions + ): + raise HTTPException( + HTTPStatus.UNAUTHORIZED, "User not authorized for extension." + ) return user diff --git a/tests/api/test_api.py b/tests/api/test_api.py index a8e6a09a3..26020b9e1 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -15,6 +15,10 @@ from ..helpers import ( # create account POST /api/v1/account @pytest.mark.asyncio async def test_create_account(client): + settings.lnbits_allow_new_accounts = False + response = await client.post("/api/v1/account", json={"name": "test"}) + assert response.status_code == 403 + settings.lnbits_allow_new_accounts = True response = await client.post("/api/v1/account", json={"name": "test"}) assert response.status_code == 200 result = response.json() @@ -39,6 +43,16 @@ async def test_create_wallet_and_delete(client, adminkey_headers_to): assert "balance_msat" in result assert "id" in result assert "adminkey" in result + + invalid_response = await client.delete( + "/api/v1/wallet", + headers={ + "X-Api-Key": result["inkey"], + "Content-type": "application/json", + }, + ) + assert invalid_response.status_code == 401 + response = await client.delete( "/api/v1/wallet", headers={