diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 39fbad93b..11e39b6f1 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -10,6 +10,7 @@ from passlib.context import CryptContext from lnbits.core.db import db from lnbits.core.models import PaymentState from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page +from lnbits.exceptions import NotFoundError from lnbits.extension_manager import ( InstallableExtension, PayToEnableInfo, @@ -61,10 +62,7 @@ async def create_account( (user_id, username, password, email, extra, now, now), ) - new_account = await get_account(user_id=user_id, conn=conn) - assert new_account, "Newly created account couldn't be retrieved" - - return new_account + return await get_account(user_id=user_id, conn=conn) async def update_account( @@ -72,9 +70,8 @@ async def update_account( username: Optional[str] = None, email: Optional[str] = None, user_config: Optional[UserConfig] = None, -) -> Optional[User]: +) -> User: user = await get_account(user_id) - assert user, "User not found" if email: assert not user.email or email == user.email, "Cannot change email." @@ -106,9 +103,7 @@ async def update_account( ), ) - user = await get_user(user_id) - assert user, "Updated account couldn't be retrieved" - return user + return await get_user(user_id) async def delete_account(user_id: str, conn: Optional[Connection] = None) -> None: @@ -151,9 +146,7 @@ async def get_accounts( ) -async def get_account( - user_id: str, conn: Optional[Connection] = None -) -> Optional[User]: +async def get_account(user_id: str, conn: Optional[Connection] = None) -> User: row = await (conn or db).fetchone( """ SELECT id, email, username, created_at, updated_at, extra @@ -161,9 +154,10 @@ async def get_account( """, (user_id,), ) - - user = User(**row) if row else None - if user and row["extra"]: + if not row: + raise NotFoundError() + user = User(**row) + if row["extra"]: user.config = UserConfig(**json.loads(row["extra"])) return user @@ -211,7 +205,7 @@ async def verify_user_password(user_id: str, password: str) -> bool: # todo: , conn: Optional[Connection] = None ?? -async def update_user_password(data: UpdateUserPassword) -> Optional[User]: +async def update_user_password(data: UpdateUserPassword) -> User: assert data.password == data.password_repeat, "Passwords do not match." # old accounts do not have a pasword @@ -235,9 +229,7 @@ async def update_user_password(data: UpdateUserPassword) -> Optional[User]: ), ) - user = await get_user(data.user_id) - assert user, "Updated account couldn't be retrieved" - return user + return await get_user(data.user_id) async def get_account_by_username( @@ -277,7 +269,7 @@ async def get_account_by_username_or_email( return user -async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]: +async def get_user(user_id: str, conn: Optional[Connection] = None) -> User: user = await (conn or db).fetchone( """ SELECT id, email, username, pass, extra, created_at, updated_at @@ -285,21 +277,20 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[ """, (user_id,), ) + if not user: + raise NotFoundError() - if user: - extensions = await get_user_active_extensions_ids(user_id, conn) - wallets = await (conn or db).fetchall( - """ - SELECT *, COALESCE(( - SELECT balance FROM balances WHERE wallet = wallets.id - ), 0) AS balance_msat - FROM wallets - WHERE "user" = ? and wallets.deleted = false - """, - (user_id,), - ) - else: - return None + extensions = await get_user_active_extensions_ids(user_id, conn) + wallets = await (conn or db).fetchall( + """ + SELECT *, COALESCE(( + SELECT balance FROM balances WHERE wallet = wallets.id + ), 0) AS balance_msat + FROM wallets + WHERE "user" = ? and wallets.deleted = false + """, + (user_id,), + ) return User( id=user["id"], @@ -534,10 +525,7 @@ async def create_wallet( ), ) - new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn) - assert new_wallet, "Newly created wallet couldn't be retrieved" - - return new_wallet + return await get_wallet(wallet_id=wallet_id, conn=conn) async def update_wallet( @@ -564,9 +552,7 @@ async def update_wallet( """, tuple(values), ) - wallet = await get_wallet(wallet_id=wallet_id, conn=conn) - assert wallet, "updated created wallet couldn't be retrieved" - return wallet + return await get_wallet(wallet_id=wallet_id, conn=conn) async def delete_wallet( @@ -637,9 +623,7 @@ async def delete_unused_wallets( ) -async def get_wallet( - wallet_id: str, conn: Optional[Connection] = None -) -> Optional[Wallet]: +async def get_wallet(wallet_id: str, conn: Optional[Connection] = None) -> Wallet: row = await (conn or db).fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) @@ -647,8 +631,10 @@ async def get_wallet( """, (wallet_id,), ) + if not row: + raise NotFoundError() - return Wallet(**row) if row else None + return Wallet(**row) async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[Wallet]: @@ -722,7 +708,7 @@ async def get_standalone_payment( async def get_wallet_payment( wallet_id: str, payment_hash: str, conn: Optional[Connection] = None -) -> Optional[Payment]: +) -> Payment: row = await (conn or db).fetchone( """ SELECT * @@ -731,8 +717,10 @@ async def get_wallet_payment( """, (wallet_id, payment_hash), ) + if not row: + raise NotFoundError() - return Payment.from_row(row) if row else None + return Payment.from_row(row) async def get_latest_payments_by_extension(ext_name: str, ext_id: str, limit: int = 5): @@ -927,10 +915,7 @@ async def create_payment( ), ) - new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn) - assert new_payment, "Newly created payment couldn't be retrieved" - - return new_payment + return await get_wallet_payment(wallet_id, payment_hash, conn=conn) async def update_payment_status( @@ -1045,10 +1030,7 @@ async def get_payments_history( ) if wallet_id: wallet = await get_wallet(wallet_id) - if wallet: - balance = wallet.balance_msat - else: - raise ValueError("Unknown wallet") + balance = wallet.balance_msat else: balance = await get_total_balance() @@ -1210,7 +1192,7 @@ async def delete_dbversion(*, ext_id: str, conn: Optional[Connection] = None) -> # ------- -async def create_tinyurl(domain: str, endless: bool, wallet: str): +async def create_tinyurl(domain: str, endless: bool, wallet: str) -> TinyURL: tinyurl_id = shortuuid.uuid()[:8] await db.execute( "INSERT INTO tiny_url (id, url, endless, wallet) VALUES (?, ?, ?, ?)", @@ -1224,12 +1206,14 @@ async def create_tinyurl(domain: str, endless: bool, wallet: str): return await get_tinyurl(tinyurl_id) -async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]: +async def get_tinyurl(tinyurl_id: str) -> TinyURL: row = await db.fetchone( "SELECT * FROM tiny_url WHERE id = ?", (tinyurl_id,), ) - return TinyURL.from_row(row) if row else None + if not row: + raise NotFoundError() + return TinyURL.from_row(row) async def get_tinyurl_by_url(url: str) -> List[TinyURL]: @@ -1267,9 +1251,7 @@ async def create_webpush_settings(webpush_settings: dict): return await get_webpush_settings() -async def get_webpush_subscription( - endpoint: str, user: str -) -> Optional[WebPushSubscription]: +async def get_webpush_subscription(endpoint: str, user: str) -> WebPushSubscription: row = await db.fetchone( """SELECT * FROM webpush_subscriptions WHERE endpoint = ? AND "user" = ?""", ( @@ -1277,12 +1259,12 @@ async def get_webpush_subscription( user, ), ) - return WebPushSubscription(**dict(row)) if row else None + if not row: + raise NotFoundError() + return WebPushSubscription(**row) -async def get_webpush_subscriptions_for_user( - user: str, -) -> List[WebPushSubscription]: +async def get_webpush_subscriptions_for_user(user: str) -> List[WebPushSubscription]: rows = await db.fetchall( """SELECT * FROM webpush_subscriptions WHERE "user" = ?""", (user,), @@ -1305,9 +1287,7 @@ async def create_webpush_subscription( host, ), ) - subscription = await get_webpush_subscription(endpoint, user) - assert subscription, "Newly created webpush subscription couldn't be retrieved" - return subscription + return await get_webpush_subscription(endpoint, user) async def delete_webpush_subscription(endpoint: str, user: str) -> int: diff --git a/lnbits/core/views/tinyurl_api.py b/lnbits/core/views/tinyurl_api.py index deb90eefd..cc2edaa09 100644 --- a/lnbits/core/views/tinyurl_api.py +++ b/lnbits/core/views/tinyurl_api.py @@ -32,16 +32,10 @@ async def api_create_tinyurl( url: str, endless: bool = False, wallet: WalletTypeInfo = Depends(require_admin_key) ): tinyurls = await get_tinyurl_by_url(url) - try: - for tinyurl in tinyurls: - if tinyurl: - if tinyurl.wallet == wallet.wallet.id: - return tinyurl - return await create_tinyurl(url, endless, wallet.wallet.id) - except Exception as exc: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail="Unable to create tinyurl" - ) from exc + for tinyurl in tinyurls: + if tinyurl.wallet == wallet.wallet.id: + return tinyurl + return await create_tinyurl(url, endless, wallet.wallet.id) @tinyurl_router.get( @@ -52,18 +46,12 @@ async def api_create_tinyurl( async def api_get_tinyurl( tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key) ): - try: - tinyurl = await get_tinyurl(tinyurl_id) - if tinyurl: - if tinyurl.wallet == wallet.wallet.id: - return tinyurl + tinyurl = await get_tinyurl(tinyurl_id) + if tinyurl.wallet != wallet.wallet.id: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided." ) - except Exception as exc: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="Unable to fetch tinyurl" - ) from exc + return tinyurl @tinyurl_router.delete( @@ -74,19 +62,13 @@ async def api_get_tinyurl( async def api_delete_tinyurl( tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_admin_key) ): - try: - tinyurl = await get_tinyurl(tinyurl_id) - if tinyurl: - if tinyurl.wallet == wallet.wallet.id: - await delete_tinyurl(tinyurl_id) - return {"deleted": True} + tinyurl = await get_tinyurl(tinyurl_id) + if tinyurl.wallet != wallet.wallet.id: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided." ) - except Exception as exc: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail="Unable to delete" - ) from exc + await delete_tinyurl(tinyurl_id) + return {"deleted": True} @tinyurl_router.get( @@ -96,10 +78,5 @@ async def api_delete_tinyurl( ) async def api_tinyurl(tinyurl_id: str): tinyurl = await get_tinyurl(tinyurl_id) - if tinyurl: - response = RedirectResponse(url=tinyurl.url) - return response - else: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, detail="unable to find tinyurl" - ) + response = RedirectResponse(url=tinyurl.url) + return response diff --git a/lnbits/exceptions.py b/lnbits/exceptions.py index c66582a68..50cc73a55 100644 --- a/lnbits/exceptions.py +++ b/lnbits/exceptions.py @@ -11,24 +11,40 @@ from loguru import logger from .helpers import template_renderer -class PaymentError(Exception): +class LnbitsError(Exception): + """Base class for all exceptions in lnbits.""" + + +class PaymentError(LnbitsError): + """raised by fundingsource pay_invoice operations when an error occurs""" + def __init__(self, message: str, status: str = "pending"): self.message = message self.status = status -class InvoiceError(Exception): +class InvoiceError(LnbitsError): + """raised by fundingsource create_invoice operations when an error occurs""" + def __init__(self, message: str, status: str = "pending"): self.message = message self.status = status +class NotFoundError(LnbitsError): + """ + Raised by crud operations when a resource is not found. + Raises (401) error in api context. + """ + + def register_exception_handlers(app: FastAPI): register_exception_handler(app) register_request_validation_exception_handler(app) register_http_exception_handler(app) register_payment_error_handler(app) register_invoice_error_handler(app) + register_not_found_error_handler(app) def render_html_error(request: Request, exc: Exception) -> Optional[Response]: @@ -115,3 +131,12 @@ def register_invoice_error_handler(app: FastAPI): status_code=520, content={"detail": exc.message, "status": exc.status}, ) + + +def register_not_found_error_handler(app: FastAPI): + @app.exception_handler(NotFoundError) + async def notfound_error_handler(request: Request, exc: NotFoundError): + return JSONResponse( + status_code=HTTPStatus.NOT_FOUND, + content={"detail": f"{exc!s}"}, + )