feat: introduce NotFoundError

instead of asserting and returning Optional on get operation in crud, we
through and NotFoundError. in the context of an API request the FastApi
errorhandler will catch it and throw an 404 like we have for Invoice-
and PaymentError.

In non-api context we can simly try catch the notfound like so:
```python
try:
  url = await get_tinyurl("test")
except NotFoundError:
  # url does not exist do whatever
  pass
```
This commit is contained in:
dni ⚡ 2024-09-05 08:08:31 +02:00
parent 8aa1716e32
commit 12b3659f02
No known key found for this signature in database
GPG Key ID: D1F416F29AD26E87
3 changed files with 88 additions and 106 deletions

View File

@ -10,6 +10,7 @@ from passlib.context import CryptContext
from lnbits.core.db import db
from lnbits.core.models import PaymentState
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
from lnbits.exceptions import NotFoundError
from lnbits.extension_manager import (
InstallableExtension,
PayToEnableInfo,
@ -61,10 +62,7 @@ async def create_account(
(user_id, username, password, email, extra, now, now),
)
new_account = await get_account(user_id=user_id, conn=conn)
assert new_account, "Newly created account couldn't be retrieved"
return new_account
return await get_account(user_id=user_id, conn=conn)
async def update_account(
@ -72,9 +70,8 @@ async def update_account(
username: Optional[str] = None,
email: Optional[str] = None,
user_config: Optional[UserConfig] = None,
) -> Optional[User]:
) -> User:
user = await get_account(user_id)
assert user, "User not found"
if email:
assert not user.email or email == user.email, "Cannot change email."
@ -106,9 +103,7 @@ async def update_account(
),
)
user = await get_user(user_id)
assert user, "Updated account couldn't be retrieved"
return user
return await get_user(user_id)
async def delete_account(user_id: str, conn: Optional[Connection] = None) -> None:
@ -151,9 +146,7 @@ async def get_accounts(
)
async def get_account(
user_id: str, conn: Optional[Connection] = None
) -> Optional[User]:
async def get_account(user_id: str, conn: Optional[Connection] = None) -> User:
row = await (conn or db).fetchone(
"""
SELECT id, email, username, created_at, updated_at, extra
@ -161,9 +154,10 @@ async def get_account(
""",
(user_id,),
)
user = User(**row) if row else None
if user and row["extra"]:
if not row:
raise NotFoundError()
user = User(**row)
if row["extra"]:
user.config = UserConfig(**json.loads(row["extra"]))
return user
@ -211,7 +205,7 @@ async def verify_user_password(user_id: str, password: str) -> bool:
# todo: , conn: Optional[Connection] = None ??
async def update_user_password(data: UpdateUserPassword) -> Optional[User]:
async def update_user_password(data: UpdateUserPassword) -> User:
assert data.password == data.password_repeat, "Passwords do not match."
# old accounts do not have a pasword
@ -235,9 +229,7 @@ async def update_user_password(data: UpdateUserPassword) -> Optional[User]:
),
)
user = await get_user(data.user_id)
assert user, "Updated account couldn't be retrieved"
return user
return await get_user(data.user_id)
async def get_account_by_username(
@ -277,7 +269,7 @@ async def get_account_by_username_or_email(
return user
async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]:
async def get_user(user_id: str, conn: Optional[Connection] = None) -> User:
user = await (conn or db).fetchone(
"""
SELECT id, email, username, pass, extra, created_at, updated_at
@ -285,21 +277,20 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
""",
(user_id,),
)
if not user:
raise NotFoundError()
if user:
extensions = await get_user_active_extensions_ids(user_id, conn)
wallets = await (conn or db).fetchall(
"""
SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet = wallets.id
), 0) AS balance_msat
FROM wallets
WHERE "user" = ? and wallets.deleted = false
""",
(user_id,),
)
else:
return None
extensions = await get_user_active_extensions_ids(user_id, conn)
wallets = await (conn or db).fetchall(
"""
SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet = wallets.id
), 0) AS balance_msat
FROM wallets
WHERE "user" = ? and wallets.deleted = false
""",
(user_id,),
)
return User(
id=user["id"],
@ -534,10 +525,7 @@ async def create_wallet(
),
)
new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert new_wallet, "Newly created wallet couldn't be retrieved"
return new_wallet
return await get_wallet(wallet_id=wallet_id, conn=conn)
async def update_wallet(
@ -564,9 +552,7 @@ async def update_wallet(
""",
tuple(values),
)
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert wallet, "updated created wallet couldn't be retrieved"
return wallet
return await get_wallet(wallet_id=wallet_id, conn=conn)
async def delete_wallet(
@ -637,9 +623,7 @@ async def delete_unused_wallets(
)
async def get_wallet(
wallet_id: str, conn: Optional[Connection] = None
) -> Optional[Wallet]:
async def get_wallet(wallet_id: str, conn: Optional[Connection] = None) -> Wallet:
row = await (conn or db).fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0)
@ -647,8 +631,10 @@ async def get_wallet(
""",
(wallet_id,),
)
if not row:
raise NotFoundError()
return Wallet(**row) if row else None
return Wallet(**row)
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[Wallet]:
@ -722,7 +708,7 @@ async def get_standalone_payment(
async def get_wallet_payment(
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
) -> Payment:
row = await (conn or db).fetchone(
"""
SELECT *
@ -731,8 +717,10 @@ async def get_wallet_payment(
""",
(wallet_id, payment_hash),
)
if not row:
raise NotFoundError()
return Payment.from_row(row) if row else None
return Payment.from_row(row)
async def get_latest_payments_by_extension(ext_name: str, ext_id: str, limit: int = 5):
@ -927,10 +915,7 @@ async def create_payment(
),
)
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment
return await get_wallet_payment(wallet_id, payment_hash, conn=conn)
async def update_payment_status(
@ -1045,10 +1030,7 @@ async def get_payments_history(
)
if wallet_id:
wallet = await get_wallet(wallet_id)
if wallet:
balance = wallet.balance_msat
else:
raise ValueError("Unknown wallet")
balance = wallet.balance_msat
else:
balance = await get_total_balance()
@ -1210,7 +1192,7 @@ async def delete_dbversion(*, ext_id: str, conn: Optional[Connection] = None) ->
# -------
async def create_tinyurl(domain: str, endless: bool, wallet: str):
async def create_tinyurl(domain: str, endless: bool, wallet: str) -> TinyURL:
tinyurl_id = shortuuid.uuid()[:8]
await db.execute(
"INSERT INTO tiny_url (id, url, endless, wallet) VALUES (?, ?, ?, ?)",
@ -1224,12 +1206,14 @@ async def create_tinyurl(domain: str, endless: bool, wallet: str):
return await get_tinyurl(tinyurl_id)
async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]:
async def get_tinyurl(tinyurl_id: str) -> TinyURL:
row = await db.fetchone(
"SELECT * FROM tiny_url WHERE id = ?",
(tinyurl_id,),
)
return TinyURL.from_row(row) if row else None
if not row:
raise NotFoundError()
return TinyURL.from_row(row)
async def get_tinyurl_by_url(url: str) -> List[TinyURL]:
@ -1267,9 +1251,7 @@ async def create_webpush_settings(webpush_settings: dict):
return await get_webpush_settings()
async def get_webpush_subscription(
endpoint: str, user: str
) -> Optional[WebPushSubscription]:
async def get_webpush_subscription(endpoint: str, user: str) -> WebPushSubscription:
row = await db.fetchone(
"""SELECT * FROM webpush_subscriptions WHERE endpoint = ? AND "user" = ?""",
(
@ -1277,12 +1259,12 @@ async def get_webpush_subscription(
user,
),
)
return WebPushSubscription(**dict(row)) if row else None
if not row:
raise NotFoundError()
return WebPushSubscription(**row)
async def get_webpush_subscriptions_for_user(
user: str,
) -> List[WebPushSubscription]:
async def get_webpush_subscriptions_for_user(user: str) -> List[WebPushSubscription]:
rows = await db.fetchall(
"""SELECT * FROM webpush_subscriptions WHERE "user" = ?""",
(user,),
@ -1305,9 +1287,7 @@ async def create_webpush_subscription(
host,
),
)
subscription = await get_webpush_subscription(endpoint, user)
assert subscription, "Newly created webpush subscription couldn't be retrieved"
return subscription
return await get_webpush_subscription(endpoint, user)
async def delete_webpush_subscription(endpoint: str, user: str) -> int:

View File

@ -32,16 +32,10 @@ async def api_create_tinyurl(
url: str, endless: bool = False, wallet: WalletTypeInfo = Depends(require_admin_key)
):
tinyurls = await get_tinyurl_by_url(url)
try:
for tinyurl in tinyurls:
if tinyurl:
if tinyurl.wallet == wallet.wallet.id:
return tinyurl
return await create_tinyurl(url, endless, wallet.wallet.id)
except Exception as exc:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Unable to create tinyurl"
) from exc
for tinyurl in tinyurls:
if tinyurl.wallet == wallet.wallet.id:
return tinyurl
return await create_tinyurl(url, endless, wallet.wallet.id)
@tinyurl_router.get(
@ -52,18 +46,12 @@ async def api_create_tinyurl(
async def api_get_tinyurl(
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
):
try:
tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl:
if tinyurl.wallet == wallet.wallet.id:
return tinyurl
tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl.wallet != wallet.wallet.id:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."
)
except Exception as exc:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="Unable to fetch tinyurl"
) from exc
return tinyurl
@tinyurl_router.delete(
@ -74,19 +62,13 @@ async def api_get_tinyurl(
async def api_delete_tinyurl(
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)
):
try:
tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl:
if tinyurl.wallet == wallet.wallet.id:
await delete_tinyurl(tinyurl_id)
return {"deleted": True}
tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl.wallet != wallet.wallet.id:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."
)
except Exception as exc:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail="Unable to delete"
) from exc
await delete_tinyurl(tinyurl_id)
return {"deleted": True}
@tinyurl_router.get(
@ -96,10 +78,5 @@ async def api_delete_tinyurl(
)
async def api_tinyurl(tinyurl_id: str):
tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl:
response = RedirectResponse(url=tinyurl.url)
return response
else:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail="unable to find tinyurl"
)
response = RedirectResponse(url=tinyurl.url)
return response

View File

@ -11,24 +11,40 @@ from loguru import logger
from .helpers import template_renderer
class PaymentError(Exception):
class LnbitsError(Exception):
"""Base class for all exceptions in lnbits."""
class PaymentError(LnbitsError):
"""raised by fundingsource pay_invoice operations when an error occurs"""
def __init__(self, message: str, status: str = "pending"):
self.message = message
self.status = status
class InvoiceError(Exception):
class InvoiceError(LnbitsError):
"""raised by fundingsource create_invoice operations when an error occurs"""
def __init__(self, message: str, status: str = "pending"):
self.message = message
self.status = status
class NotFoundError(LnbitsError):
"""
Raised by crud operations when a resource is not found.
Raises (401) error in api context.
"""
def register_exception_handlers(app: FastAPI):
register_exception_handler(app)
register_request_validation_exception_handler(app)
register_http_exception_handler(app)
register_payment_error_handler(app)
register_invoice_error_handler(app)
register_not_found_error_handler(app)
def render_html_error(request: Request, exc: Exception) -> Optional[Response]:
@ -115,3 +131,12 @@ def register_invoice_error_handler(app: FastAPI):
status_code=520,
content={"detail": exc.message, "status": exc.status},
)
def register_not_found_error_handler(app: FastAPI):
@app.exception_handler(NotFoundError)
async def notfound_error_handler(request: Request, exc: NotFoundError):
return JSONResponse(
status_code=HTTPStatus.NOT_FOUND,
content={"detail": f"{exc!s}"},
)