diff --git a/lnbits/core/models.py b/lnbits/core/models.py index ef0a38e04..83214c745 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -453,3 +453,8 @@ class BalanceDelta(BaseModel): @property def delta_msats(self): return self.node_balance_msats - self.lnbits_balance_msats + + +class SimpleStatus(BaseModel): + success: bool + message: str diff --git a/lnbits/core/services.py b/lnbits/core/services.py index a9a705211..b3cf795f6 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -17,7 +17,11 @@ from py_vapid.utils import b64urlencode from lnbits.core.db import db from lnbits.db import Connection -from lnbits.decorators import WalletTypeInfo, require_admin_key +from lnbits.decorators import ( + WalletTypeInfo, + check_user_extension_access, + require_admin_key, +) from lnbits.helpers import url_for from lnbits.lnurl import LnurlErrorResponse from lnbits.lnurl import decode as decode_lnurl @@ -300,18 +304,13 @@ async def pay_invoice( # do the balance check wallet = await get_wallet(wallet_id, conn=conn) assert wallet, "Wallet for balancecheck could not be fetched" - if wallet.balance_msat < 0: - logger.debug("balance is too low, deleting temporary payment") - if ( - not internal_checking_id - and wallet.balance_msat > -fee_reserve_total_msat - ): - raise PaymentError( - f"You must reserve at least ({round(fee_reserve_total_msat/1000)}" - " sat) to cover potential routing fees.", - status="failed", - ) - raise PaymentError("Insufficient balance.", status="failed") + _check_wallet_balance(wallet, fee_reserve_total_msat, internal_checking_id) + + if extra and "tag" in extra: + # check if the payment is made for an extension that the user disabled + status = await check_user_extension_access(wallet.user, extra["tag"]) + if not status.success: + raise PaymentError(status.message) if internal_checking_id: service_fee_msat = service_fee(invoice.amount_msat, internal=True) @@ -402,6 +401,22 @@ async def pay_invoice( return invoice.payment_hash +def _check_wallet_balance( + wallet: Wallet, + fee_reserve_total_msat: int, + internal_checking_id: Optional[str] = None, +): + if wallet.balance_msat < 0: + logger.debug("balance is too low, deleting temporary payment") + if not internal_checking_id and wallet.balance_msat > -fee_reserve_total_msat: + raise PaymentError( + f"You must reserve at least ({round(fee_reserve_total_msat/1000)}" + " sat) to cover potential routing fees.", + status="failed", + ) + raise PaymentError("Insufficient balance.", status="failed") + + async def check_wallet_limits(wallet_id, conn, amount_msat): await check_time_limit_between_transactions(conn, wallet_id) await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 26bf38eeb..f6d04b1b7 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -18,7 +18,7 @@ from lnbits.core.crud import ( get_user_active_extensions_ids, get_wallet_for_key, ) -from lnbits.core.models import KeyType, User, WalletTypeInfo +from lnbits.core.models import KeyType, SimpleStatus, User, WalletTypeInfo from lnbits.db import Filter, Filters, TFilterModel from lnbits.settings import AuthMethods, settings @@ -210,27 +210,36 @@ def parse_filters(model: Type[TFilterModel]): return dependency -async def _check_user_extension_access(user_id: str, current_path: str): +async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus: """ Check if the user has access to a particular extension. Raises HTTP Forbidden if the user is not allowed. """ - path = current_path.split("/") - ext_id = path[3] if path[1] == "upgrades" else path[1] if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id): - raise HTTPException( - HTTPStatus.FORBIDDEN, - f"User not authorized for extension '{ext_id}'.", + return SimpleStatus( + success=False, message=f"User not authorized for extension '{ext_id}'." ) if settings.is_extension_id(ext_id): ext_ids = await get_user_active_extensions_ids(user_id) if ext_id not in ext_ids: - raise HTTPException( - HTTPStatus.FORBIDDEN, - f"User extension '{ext_id}' not enabled.", + return SimpleStatus( + success=False, message=f"User extension '{ext_id}' not enabled." ) + return SimpleStatus(success=True, message="OK") + + +async def _check_user_extension_access(user_id: str, current_path: str): + path = current_path.split("/") + ext_id = path[3] if path[1] == "upgrades" else path[1] + status = await check_user_extension_access(user_id, ext_id) + if not status.success: + raise HTTPException( + HTTPStatus.FORBIDDEN, + status.message, + ) + async def _get_account_from_token(access_token): try: