mirror of
https://github.com/lnbits/lnbits.git
synced 2025-06-29 18:10:46 +02:00
refactor: fix duplicate keychecker (#2339)
* refactor: fix duplicate keychecker - refactor KeyChecker to be more approachable - only 1 sql query needed even if you use `get_key_type` - rename `WalletType` to `KeyType` wallet type was misleading fix test sorting * fixup! * revert 404
This commit is contained in:
@ -8,7 +8,6 @@ import shortuuid
|
|||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
from lnbits.core.db import db
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models import WalletType
|
|
||||||
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
|
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
|
||||||
from lnbits.extension_manager import InstallableExtension
|
from lnbits.extension_manager import InstallableExtension
|
||||||
from lnbits.settings import (
|
from lnbits.settings import (
|
||||||
@ -628,7 +627,6 @@ async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[W
|
|||||||
|
|
||||||
async def get_wallet_for_key(
|
async def get_wallet_for_key(
|
||||||
key: str,
|
key: str,
|
||||||
key_type: WalletType = WalletType.invoice,
|
|
||||||
conn: Optional[Connection] = None,
|
conn: Optional[Connection] = None,
|
||||||
) -> Optional[Wallet]:
|
) -> Optional[Wallet]:
|
||||||
row = await (conn or db).fetchone(
|
row = await (conn or db).fetchone(
|
||||||
@ -643,9 +641,6 @@ async def get_wallet_for_key(
|
|||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if key_type == WalletType.admin and row["adminkey"] != key:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return Wallet(**row)
|
return Wallet(**row)
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ class Wallet(BaseWallet):
|
|||||||
return await get_standalone_payment(payment_hash)
|
return await get_standalone_payment(payment_hash)
|
||||||
|
|
||||||
|
|
||||||
class WalletType(Enum):
|
class KeyType(Enum):
|
||||||
admin = 0
|
admin = 0
|
||||||
invoice = 1
|
invoice = 1
|
||||||
invalid = 2
|
invalid = 2
|
||||||
@ -80,7 +80,7 @@ class WalletType(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WalletTypeInfo:
|
class WalletTypeInfo:
|
||||||
wallet_type: WalletType
|
key_type: KeyType
|
||||||
wallet: Wallet
|
wallet: Wallet
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,8 +25,8 @@ from lnbits.core.models import (
|
|||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
WalletTypeInfo,
|
WalletTypeInfo,
|
||||||
check_user_exists,
|
check_user_exists,
|
||||||
get_key_type,
|
|
||||||
require_admin_key,
|
require_admin_key,
|
||||||
|
require_invoice_key,
|
||||||
)
|
)
|
||||||
from lnbits.lnurl import decode as lnurl_decode
|
from lnbits.lnurl import decode as lnurl_decode
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
@ -67,7 +67,7 @@ async def api_wallets(user: User = Depends(check_user_exists)) -> List[BaseWalle
|
|||||||
async def api_create_account(data: CreateWallet) -> Wallet:
|
async def api_create_account(data: CreateWallet) -> Wallet:
|
||||||
if not settings.new_accounts_allowed:
|
if not settings.new_accounts_allowed:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
detail="Account creation is disabled.",
|
detail="Account creation is disabled.",
|
||||||
)
|
)
|
||||||
account = await create_account()
|
account = await create_account()
|
||||||
@ -75,7 +75,9 @@ async def api_create_account(data: CreateWallet) -> Wallet:
|
|||||||
|
|
||||||
|
|
||||||
@api_router.get("/api/v1/lnurlscan/{code}")
|
@api_router.get("/api/v1/lnurlscan/{code}")
|
||||||
async def api_lnurlscan(code: str, wallet: WalletTypeInfo = Depends(get_key_type)):
|
async def api_lnurlscan(
|
||||||
|
code: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
url = str(lnurl_decode(code))
|
url = str(lnurl_decode(code))
|
||||||
domain = urlparse(url).netloc
|
domain = urlparse(url).netloc
|
||||||
|
@ -26,11 +26,11 @@ from lnbits.core.models import (
|
|||||||
CreateInvoice,
|
CreateInvoice,
|
||||||
CreateLnurl,
|
CreateLnurl,
|
||||||
DecodePayment,
|
DecodePayment,
|
||||||
|
KeyType,
|
||||||
Payment,
|
Payment,
|
||||||
PaymentFilters,
|
PaymentFilters,
|
||||||
PaymentHistoryPoint,
|
PaymentHistoryPoint,
|
||||||
Wallet,
|
Wallet,
|
||||||
WalletType,
|
|
||||||
)
|
)
|
||||||
from lnbits.db import Filters, Page
|
from lnbits.db import Filters, Page
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
@ -252,7 +252,7 @@ async def api_payments_create(
|
|||||||
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
invoice_data: CreateInvoice = Body(...),
|
invoice_data: CreateInvoice = Body(...),
|
||||||
):
|
):
|
||||||
if invoice_data.out is True and wallet.wallet_type == WalletType.admin:
|
if invoice_data.out is True and wallet.key_type == KeyType.admin:
|
||||||
if not invoice_data.bolt11:
|
if not invoice_data.bolt11:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
|
@ -8,8 +8,8 @@ from fastapi import (
|
|||||||
|
|
||||||
from lnbits.core.models import (
|
from lnbits.core.models import (
|
||||||
CreateWallet,
|
CreateWallet,
|
||||||
|
KeyType,
|
||||||
Wallet,
|
Wallet,
|
||||||
WalletType,
|
|
||||||
)
|
)
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
WalletTypeInfo,
|
WalletTypeInfo,
|
||||||
@ -28,7 +28,7 @@ wallet_router = APIRouter(prefix="/api/v1/wallet", tags=["Wallet"])
|
|||||||
|
|
||||||
@wallet_router.get("")
|
@wallet_router.get("")
|
||||||
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
async def api_wallet(wallet: WalletTypeInfo = Depends(get_key_type)):
|
||||||
if wallet.wallet_type == WalletType.admin:
|
if wallet.key_type == KeyType.admin:
|
||||||
return {
|
return {
|
||||||
"id": wallet.wallet.id,
|
"id": wallet.wallet.id,
|
||||||
"name": wallet.wallet.name,
|
"name": wallet.wallet.name,
|
||||||
|
@ -17,23 +17,32 @@ from lnbits.core.crud import (
|
|||||||
get_user,
|
get_user,
|
||||||
get_wallet_for_key,
|
get_wallet_for_key,
|
||||||
)
|
)
|
||||||
from lnbits.core.models import User, Wallet, WalletType, WalletTypeInfo
|
from lnbits.core.models import KeyType, User, WalletTypeInfo
|
||||||
from lnbits.db import Filter, Filters, TFilterModel
|
from lnbits.db import Filter, Filters, TFilterModel
|
||||||
from lnbits.settings import AuthMethods, settings
|
from lnbits.settings import AuthMethods, settings
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False)
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False)
|
||||||
|
|
||||||
|
api_key_header = APIKeyHeader(
|
||||||
|
name="X-API-KEY",
|
||||||
|
auto_error=False,
|
||||||
|
description="Admin or Invoice key for wallet API's",
|
||||||
|
)
|
||||||
|
api_key_query = APIKeyQuery(
|
||||||
|
name="api-key",
|
||||||
|
auto_error=False,
|
||||||
|
description="Admin or Invoice key for wallet API's",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KeyChecker(SecurityBase):
|
class KeyChecker(SecurityBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scheme_name: Optional[str] = None,
|
|
||||||
auto_error: bool = True,
|
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
|
expected_key_type: Optional[KeyType] = None,
|
||||||
):
|
):
|
||||||
self.scheme_name = scheme_name or self.__class__.__name__
|
self.auto_error: bool = True
|
||||||
self.auto_error: bool = auto_error
|
self.expected_key_type = expected_key_type
|
||||||
self._key_type: WalletType = WalletType.invoice
|
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
if api_key:
|
if api_key:
|
||||||
openapi_model = APIKey(
|
openapi_model = APIKey(
|
||||||
@ -49,185 +58,82 @@ class KeyChecker(SecurityBase):
|
|||||||
name="X-API-KEY",
|
name="X-API-KEY",
|
||||||
description="Wallet API Key - HEADER",
|
description="Wallet API Key - HEADER",
|
||||||
)
|
)
|
||||||
self.wallet: Optional[Wallet] = None
|
|
||||||
self.model: APIKey = openapi_model
|
self.model: APIKey = openapi_model
|
||||||
|
|
||||||
async def __call__(self, request: Request):
|
async def __call__(self, request: Request) -> WalletTypeInfo:
|
||||||
try:
|
|
||||||
key_value = (
|
key_value = (
|
||||||
self._api_key
|
self._api_key
|
||||||
if self._api_key
|
if self._api_key
|
||||||
else request.headers.get("X-API-KEY") or request.query_params["api-key"]
|
else request.headers.get("X-API-KEY") or request.query_params.get("api-key")
|
||||||
)
|
)
|
||||||
# FIXME: Find another way to validate the key. A fetch from DB should be
|
|
||||||
# avoided here. Also, we should not return the wallet here - thats
|
if not key_value:
|
||||||
# silly. Possibly store it in a Redis DB
|
|
||||||
wallet = await get_wallet_for_key(key_value, self._key_type)
|
|
||||||
if not wallet:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.UNAUTHORIZED,
|
|
||||||
detail="Invalid key or wallet.",
|
|
||||||
)
|
|
||||||
self.wallet = wallet
|
|
||||||
except KeyError as exc:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST, detail="`X-API-KEY` header missing."
|
status_code=HTTPStatus.UNAUTHORIZED,
|
||||||
) from exc
|
detail="No Api Key provided.",
|
||||||
|
)
|
||||||
|
|
||||||
|
wallet = await get_wallet_for_key(key_value)
|
||||||
|
|
||||||
class WalletInvoiceKeyChecker(KeyChecker):
|
if not wallet:
|
||||||
"""
|
raise HTTPException(
|
||||||
WalletInvoiceKeyChecker will ensure that the provided invoice
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
wallet key is correct and populate g().wallet with the wallet
|
detail="Wallet not found.",
|
||||||
for the key in `X-API-key`.
|
)
|
||||||
|
|
||||||
The checker will raise an HTTPException when the key is wrong in some ways.
|
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
|
||||||
"""
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.UNAUTHORIZED,
|
||||||
|
detail="Invalid adminkey.",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
if (
|
||||||
self,
|
wallet.user != settings.super_user
|
||||||
scheme_name: Optional[str] = None,
|
and wallet.user not in settings.lnbits_admin_users
|
||||||
auto_error: bool = True,
|
and settings.lnbits_admin_extensions
|
||||||
api_key: Optional[str] = None,
|
and request["path"].split("/")[1] in settings.lnbits_admin_extensions
|
||||||
):
|
):
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
raise HTTPException(
|
||||||
self._key_type = WalletType.invoice
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
|
detail="User not authorized for this extension.",
|
||||||
|
)
|
||||||
|
|
||||||
|
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||||
class WalletAdminKeyChecker(KeyChecker):
|
return WalletTypeInfo(key_type, wallet)
|
||||||
"""
|
|
||||||
WalletAdminKeyChecker will ensure that the provided admin
|
|
||||||
wallet key is correct and populate g().wallet with the wallet
|
|
||||||
for the key in `X-API-key`.
|
|
||||||
|
|
||||||
The checker will raise an HTTPException when the key is wrong in some ways.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
scheme_name: Optional[str] = None,
|
|
||||||
auto_error: bool = True,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
):
|
|
||||||
super().__init__(scheme_name, auto_error, api_key)
|
|
||||||
self._key_type = WalletType.admin
|
|
||||||
|
|
||||||
|
|
||||||
api_key_header = APIKeyHeader(
|
|
||||||
name="X-API-KEY",
|
|
||||||
auto_error=False,
|
|
||||||
description="Admin or Invoice key for wallet API's",
|
|
||||||
)
|
|
||||||
api_key_query = APIKeyQuery(
|
|
||||||
name="api-key",
|
|
||||||
auto_error=False,
|
|
||||||
description="Admin or Invoice key for wallet API's",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_key_type(
|
async def get_key_type(
|
||||||
r: Request,
|
request: Request,
|
||||||
api_key_header: str = Security(api_key_header),
|
api_key_header: str = Security(api_key_header),
|
||||||
api_key_query: str = Security(api_key_query),
|
api_key_query: str = Security(api_key_query),
|
||||||
) -> WalletTypeInfo:
|
) -> WalletTypeInfo:
|
||||||
token = api_key_header or api_key_query
|
check: KeyChecker = KeyChecker(api_key=api_key_header or api_key_query)
|
||||||
|
return await check(request)
|
||||||
if not token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.UNAUTHORIZED,
|
|
||||||
detail="Invoice (or Admin) key required.",
|
|
||||||
)
|
|
||||||
|
|
||||||
for wallet_type, wallet_checker in zip(
|
|
||||||
[WalletType.admin, WalletType.invoice],
|
|
||||||
[WalletAdminKeyChecker, WalletInvoiceKeyChecker],
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
checker = wallet_checker(api_key=token)
|
|
||||||
await checker.__call__(r)
|
|
||||||
if checker.wallet is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
|
|
||||||
)
|
|
||||||
wallet = WalletTypeInfo(wallet_type, checker.wallet)
|
|
||||||
if (
|
|
||||||
wallet.wallet.user != settings.super_user
|
|
||||||
and wallet.wallet.user not in settings.lnbits_admin_users
|
|
||||||
) and (
|
|
||||||
settings.lnbits_admin_extensions
|
|
||||||
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.FORBIDDEN,
|
|
||||||
detail="User not authorized for this extension.",
|
|
||||||
)
|
|
||||||
return wallet
|
|
||||||
except HTTPException as exc:
|
|
||||||
if exc.status_code == HTTPStatus.BAD_REQUEST:
|
|
||||||
raise
|
|
||||||
elif exc.status_code == HTTPStatus.UNAUTHORIZED:
|
|
||||||
# we pass this in case it is not an invoice key, nor an admin key,
|
|
||||||
# and then return NOT_FOUND at the end of this block
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
raise
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.NOT_FOUND, detail="Wallet does not exist."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def require_admin_key(
|
async def require_admin_key(
|
||||||
r: Request,
|
request: Request,
|
||||||
api_key_header: str = Security(api_key_header),
|
api_key_header: str = Security(api_key_header),
|
||||||
api_key_query: str = Security(api_key_query),
|
api_key_query: str = Security(api_key_query),
|
||||||
):
|
) -> WalletTypeInfo:
|
||||||
token = api_key_header or api_key_query
|
check: KeyChecker = KeyChecker(
|
||||||
|
api_key=api_key_header or api_key_query,
|
||||||
if not token:
|
expected_key_type=KeyType.admin,
|
||||||
raise HTTPException(
|
)
|
||||||
status_code=HTTPStatus.UNAUTHORIZED,
|
return await check(request)
|
||||||
detail="Admin key required.",
|
|
||||||
)
|
|
||||||
|
|
||||||
wallet = await get_key_type(r, token)
|
|
||||||
|
|
||||||
if wallet.wallet_type != 0:
|
|
||||||
# If wallet type is not admin then return the unauthorized status
|
|
||||||
# This also covers when the user passes an invalid key type
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.UNAUTHORIZED, detail="Admin key required."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return wallet
|
|
||||||
|
|
||||||
|
|
||||||
async def require_invoice_key(
|
async def require_invoice_key(
|
||||||
r: Request,
|
request: Request,
|
||||||
api_key_header: str = Security(api_key_header),
|
api_key_header: str = Security(api_key_header),
|
||||||
api_key_query: str = Security(api_key_query),
|
api_key_query: str = Security(api_key_query),
|
||||||
):
|
) -> WalletTypeInfo:
|
||||||
token = api_key_header or api_key_query
|
check: KeyChecker = KeyChecker(
|
||||||
|
api_key=api_key_header or api_key_query,
|
||||||
if not token:
|
expected_key_type=KeyType.invoice,
|
||||||
raise HTTPException(
|
)
|
||||||
status_code=HTTPStatus.UNAUTHORIZED,
|
return await check(request)
|
||||||
detail="Invoice (or Admin) key required.",
|
|
||||||
)
|
|
||||||
|
|
||||||
wallet = await get_key_type(r, token)
|
|
||||||
|
|
||||||
if (
|
|
||||||
wallet.wallet_type != WalletType.admin
|
|
||||||
and wallet.wallet_type != WalletType.invoice
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.UNAUTHORIZED,
|
|
||||||
detail="Invoice (or Admin) key required.",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return wallet
|
|
||||||
|
|
||||||
|
|
||||||
async def check_access_token(
|
async def check_access_token(
|
||||||
@ -255,8 +161,15 @@ async def check_user_exists(
|
|||||||
user = await get_user(account.id)
|
user = await get_user(account.id)
|
||||||
assert user, "User not found for account."
|
assert user, "User not found for account."
|
||||||
|
|
||||||
if not user.admin and r["path"].split("/")[1] in settings.lnbits_admin_extensions:
|
if (
|
||||||
raise HTTPException(HTTPStatus.FORBIDDEN, "User not authorized for extension.")
|
user.id != settings.super_user
|
||||||
|
and user.id not in settings.lnbits_admin_users
|
||||||
|
and settings.lnbits_admin_extensions
|
||||||
|
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.UNAUTHORIZED, "User not authorized for extension."
|
||||||
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
@ -15,6 +15,10 @@ from ..helpers import (
|
|||||||
# create account POST /api/v1/account
|
# create account POST /api/v1/account
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_account(client):
|
async def test_create_account(client):
|
||||||
|
settings.lnbits_allow_new_accounts = False
|
||||||
|
response = await client.post("/api/v1/account", json={"name": "test"})
|
||||||
|
assert response.status_code == 403
|
||||||
|
settings.lnbits_allow_new_accounts = True
|
||||||
response = await client.post("/api/v1/account", json={"name": "test"})
|
response = await client.post("/api/v1/account", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
result = response.json()
|
result = response.json()
|
||||||
@ -39,6 +43,16 @@ async def test_create_wallet_and_delete(client, adminkey_headers_to):
|
|||||||
assert "balance_msat" in result
|
assert "balance_msat" in result
|
||||||
assert "id" in result
|
assert "id" in result
|
||||||
assert "adminkey" in result
|
assert "adminkey" in result
|
||||||
|
|
||||||
|
invalid_response = await client.delete(
|
||||||
|
"/api/v1/wallet",
|
||||||
|
headers={
|
||||||
|
"X-Api-Key": result["inkey"],
|
||||||
|
"Content-type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invalid_response.status_code == 401
|
||||||
|
|
||||||
response = await client.delete(
|
response = await client.delete(
|
||||||
"/api/v1/wallet",
|
"/api/v1/wallet",
|
||||||
headers={
|
headers={
|
||||||
|
Reference in New Issue
Block a user