refactor: fix duplicate keychecker (#2339)

* refactor: fix duplicate keychecker

- refactor KeyChecker to be more approachable
- only 1 sql query needed even if you use `get_key_type`
- rename `WalletType` to `KeyType` wallet type was misleading

fix test

sorting

* fixup!

* revert 404
This commit is contained in:
dni ⚡
2024-05-13 16:26:25 +02:00
committed by GitHub
parent 9f8942a921
commit 6730c6ed67
7 changed files with 101 additions and 177 deletions

View File

@ -8,7 +8,6 @@ import shortuuid
from passlib.context import CryptContext from passlib.context import CryptContext
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.core.models import WalletType
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
from lnbits.extension_manager import InstallableExtension from lnbits.extension_manager import InstallableExtension
from lnbits.settings import ( from lnbits.settings import (
@ -628,7 +627,6 @@ async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[W
async def get_wallet_for_key( async def get_wallet_for_key(
key: str, key: str,
key_type: WalletType = WalletType.invoice,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Optional[Wallet]: ) -> Optional[Wallet]:
row = await (conn or db).fetchone( row = await (conn or db).fetchone(
@ -643,9 +641,6 @@ async def get_wallet_for_key(
if not row: if not row:
return None return None
if key_type == WalletType.admin and row["adminkey"] != key:
return None
return Wallet(**row) return Wallet(**row)

View File

@ -68,7 +68,7 @@ class Wallet(BaseWallet):
return await get_standalone_payment(payment_hash) return await get_standalone_payment(payment_hash)
class WalletType(Enum): class KeyType(Enum):
admin = 0 admin = 0
invoice = 1 invoice = 1
invalid = 2 invalid = 2
@ -80,7 +80,7 @@ class WalletType(Enum):
@dataclass @dataclass
class WalletTypeInfo: class WalletTypeInfo:
wallet_type: WalletType key_type: KeyType
wallet: Wallet wallet: Wallet

View File

@ -25,8 +25,8 @@ from lnbits.core.models import (
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
check_user_exists, check_user_exists,
get_key_type,
require_admin_key, require_admin_key,
require_invoice_key,
) )
from lnbits.lnurl import decode as lnurl_decode from lnbits.lnurl import decode as lnurl_decode
from lnbits.settings import settings from lnbits.settings import settings
@ -67,7 +67,7 @@ async def api_wallets(user: User = Depends(check_user_exists)) -> List[BaseWalle
async def api_create_account(data: CreateWallet) -> Wallet: async def api_create_account(data: CreateWallet) -> Wallet:
if not settings.new_accounts_allowed: if not settings.new_accounts_allowed:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.FORBIDDEN,
detail="Account creation is disabled.", detail="Account creation is disabled.",
) )
account = await create_account() account = await create_account()
@ -75,7 +75,9 @@ async def api_create_account(data: CreateWallet) -> Wallet:
@api_router.get("/api/v1/lnurlscan/{code}") @api_router.get("/api/v1/lnurlscan/{code}")
async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type)): async def api_lnurlscan(
code: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
):
try: try:
url = str(lnurl_decode(code)) url = str(lnurl_decode(code))
domain = urlparse(url).netloc domain = urlparse(url).netloc

View File

@ -26,11 +26,11 @@ from lnbits.core.models import (
CreateInvoice, CreateInvoice,
CreateLnurl, CreateLnurl,
DecodePayment, DecodePayment,
KeyType,
Payment, Payment,
PaymentFilters, PaymentFilters,
PaymentHistoryPoint, PaymentHistoryPoint,
Wallet, Wallet,
WalletType,
) )
from lnbits.db import Filters, Page from lnbits.db import Filters, Page
from lnbits.decorators import ( from lnbits.decorators import (
@ -252,7 +252,7 @@ async def api_payments_create(
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
invoice_data: CreateInvoice = Body(...), invoice_data: CreateInvoice = Body(...),
): ):
if invoice_data.out is True and wallet.wallet_type == WalletType.admin: if invoice_data.out is True and wallet.key_type == KeyType.admin:
if not invoice_data.bolt11: if not invoice_data.bolt11:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,

View File

@ -8,8 +8,8 @@ from fastapi import (
from lnbits.core.models import ( from lnbits.core.models import (
CreateWallet, CreateWallet,
KeyType,
Wallet, Wallet,
WalletType,
) )
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
@ -28,7 +28,7 @@ wallet_router = APIRouter(prefix="/api/v1/wallet", tags=["Wallet"])
@wallet_router.get("") @wallet_router.get("")
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)): async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
if wallet.wallet_type == WalletType.admin: if wallet.key_type == KeyType.admin:
return { return {
"id": wallet.wallet.id, "id": wallet.wallet.id,
"name": wallet.wallet.name, "name": wallet.wallet.name,

View File

@ -17,23 +17,32 @@ from lnbits.core.crud import (
get_user, get_user,
get_wallet_for_key, get_wallet_for_key,
) )
from lnbits.core.models import User, Wallet, WalletType, WalletTypeInfo from lnbits.core.models import KeyType, User, WalletTypeInfo
from lnbits.db import Filter, Filters, TFilterModel from lnbits.db import Filter, Filters, TFilterModel
from lnbits.settings import AuthMethods, settings from lnbits.settings import AuthMethods, settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False)
api_key_header = APIKeyHeader(
name="X-API-KEY",
auto_error=False,
description="Admin or Invoice key for wallet API's",
)
api_key_query = APIKeyQuery(
name="api-key",
auto_error=False,
description="Admin or Invoice key for wallet API's",
)
class KeyChecker(SecurityBase): class KeyChecker(SecurityBase):
def __init__( def __init__(
self, self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None, api_key: Optional[str] = None,
expected_key_type: Optional[KeyType] = None,
): ):
self.scheme_name = scheme_name or self.__class__.__name__ self.auto_error: bool = True
self.auto_error: bool = auto_error self.expected_key_type = expected_key_type
self._key_type: WalletType = WalletType.invoice
self._api_key = api_key self._api_key = api_key
if api_key: if api_key:
openapi_model = APIKey( openapi_model = APIKey(
@ -49,185 +58,82 @@ class KeyChecker(SecurityBase):
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - HEADER", description="Wallet API Key - HEADER",
) )
self.wallet: Optional[Wallet] = None
self.model: APIKey = openapi_model self.model: APIKey = openapi_model
async def __call__(self, request: Request): async def __call__(self, request: Request) -> WalletTypeInfo:
try:
key_value = ( key_value = (
self._api_key self._api_key
if self._api_key if self._api_key
else request.headers.get("X-API-KEY") or request.query_params["api-key"] else request.headers.get("X-API-KEY") or request.query_params.get("api-key")
) )
# FIXME: Find another way to validate the key. A fetch from DB should be
# avoided here. Also, we should not return the wallet here - thats if not key_value:
# silly. Possibly store it in a Redis DB
wallet = await get_wallet_for_key(key_value, self._key_type)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Invalid key or wallet.",
)
self.wallet = wallet
except KeyError as exc:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="`X-API-KEY` header missing." status_code=HTTPStatus.UNAUTHORIZED,
) from exc detail="No Api Key provided.",
)
wallet = await get_wallet_for_key(key_value)
class WalletInvoiceKeyChecker(KeyChecker): if not wallet:
""" raise HTTPException(
WalletInvoiceKeyChecker will ensure that the provided invoice status_code=HTTPStatus.NOT_FOUND,
wallet key is correct and populate g().wallet with the wallet detail="Wallet not found.",
for the key in `X-API-key`. )
The checker will raise an HTTPException when the key is wrong in some ways. if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
""" raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Invalid adminkey.",
)
def __init__( if (
self, wallet.user != settings.super_user
scheme_name: Optional[str] = None, and wallet.user not in settings.lnbits_admin_users
auto_error: bool = True, and settings.lnbits_admin_extensions
api_key: Optional[str] = None, and request["path"].split("/")[1] in settings.lnbits_admin_extensions
): ):
super().__init__(scheme_name, auto_error, api_key) raise HTTPException(
self._key_type = WalletType.invoice status_code=HTTPStatus.FORBIDDEN,
detail="User not authorized for this extension.",
)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
class WalletAdminKeyChecker(KeyChecker): return WalletTypeInfo(key_type, wallet)
"""
WalletAdminKeyChecker will ensure that the provided admin
wallet key is correct and populate g().wallet with the wallet
for the key in `X-API-key`.
The checker will raise an HTTPException when the key is wrong in some ways.
"""
def __init__(
self,
scheme_name: Optional[str] = None,
auto_error: bool = True,
api_key: Optional[str] = None,
):
super().__init__(scheme_name, auto_error, api_key)
self._key_type = WalletType.admin
api_key_header = APIKeyHeader(
name="X-API-KEY",
auto_error=False,
description="Admin or Invoice key for wallet API's",
)
api_key_query = APIKeyQuery(
name="api-key",
auto_error=False,
description="Admin or Invoice key for wallet API's",
)
async def get_key_type( async def get_key_type(
r: Request, request: Request,
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query: str = Security(api_key_query),
) -> WalletTypeInfo: ) -> WalletTypeInfo:
token = api_key_header or api_key_query check: KeyChecker = KeyChecker(api_key=api_key_header or api_key_query)
return await check(request)
if not token:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Invoice (or Admin) key required.",
)
for wallet_type, wallet_checker in zip(
[WalletType.admin, WalletType.invoice],
[WalletAdminKeyChecker, WalletInvoiceKeyChecker],
):
try:
checker = wallet_checker(api_key=token)
await checker.__call__(r)
if checker.wallet is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
)
wallet = WalletTypeInfo(wallet_type, checker.wallet)
if (
wallet.wallet.user != settings.super_user
and wallet.wallet.user not in settings.lnbits_admin_users
) and (
settings.lnbits_admin_extensions
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User not authorized for this extension.",
)
return wallet
except HTTPException as exc:
if exc.status_code == HTTPStatus.BAD_REQUEST:
raise
elif exc.status_code == HTTPStatus.UNAUTHORIZED:
# we pass this in case it is not an invoice key, nor an admin key,
# and then return NOT_FOUND at the end of this block
pass
else:
raise
except Exception:
raise
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
)
async def require_admin_key( async def require_admin_key(
r: Request, request: Request,
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query: str = Security(api_key_query),
): ) -> WalletTypeInfo:
token = api_key_header or api_key_query check: KeyChecker = KeyChecker(
api_key=api_key_header or api_key_query,
if not token: expected_key_type=KeyType.admin,
raise HTTPException( )
status_code=HTTPStatus.UNAUTHORIZED, return await check(request)
detail="Admin key required.",
)
wallet = await get_key_type(r, token)
if wallet.wallet_type != 0:
# If wallet type is not admin then return the unauthorized status
# This also covers when the user passes an invalid key type
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required."
)
else:
return wallet
async def require_invoice_key( async def require_invoice_key(
r: Request, request: Request,
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query: str = Security(api_key_query),
): ) -> WalletTypeInfo:
token = api_key_header or api_key_query check: KeyChecker = KeyChecker(
api_key=api_key_header or api_key_query,
if not token: expected_key_type=KeyType.invoice,
raise HTTPException( )
status_code=HTTPStatus.UNAUTHORIZED, return await check(request)
detail="Invoice (or Admin) key required.",
)
wallet = await get_key_type(r, token)
if (
wallet.wallet_type != WalletType.admin
and wallet.wallet_type != WalletType.invoice
):
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Invoice (or Admin) key required.",
)
else:
return wallet
async def check_access_token( async def check_access_token(
@ -255,8 +161,15 @@ async def check_user_exists(
user = await get_user(account.id) user = await get_user(account.id)
assert user, "User not found for account." assert user, "User not found for account."
if not user.admin and r["path"].split("/")[1] in settings.lnbits_admin_extensions: if (
raise HTTPException(HTTPStatus.FORBIDDEN, "User not authorized for extension.") user.id != settings.super_user
and user.id not in settings.lnbits_admin_users
and settings.lnbits_admin_extensions
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "User not authorized for extension."
)
return user return user

View File

@ -15,6 +15,10 @@ from ..helpers import (
# create account POST /api/v1/account # create account POST /api/v1/account
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_account(client): async def test_create_account(client):
settings.lnbits_allow_new_accounts = False
response = await client.post("/api/v1/account", json={"name": "test"})
assert response.status_code == 403
settings.lnbits_allow_new_accounts = True
response = await client.post("/api/v1/account", json={"name": "test"}) response = await client.post("/api/v1/account", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
result = response.json() result = response.json()
@ -39,6 +43,16 @@ async def test_create_wallet_and_delete(client, adminkey_headers_to):
assert "balance_msat" in result assert "balance_msat" in result
assert "id" in result assert "id" in result
assert "adminkey" in result assert "adminkey" in result
invalid_response = await client.delete(
"/api/v1/wallet",
headers={
"X-Api-Key": result["inkey"],
"Content-type": "application/json",
},
)
assert invalid_response.status_code == 401
response = await client.delete( response = await client.delete(
"/api/v1/wallet", "/api/v1/wallet",
headers={ headers={