mirror of
https://github.com/lnbits/lnbits.git
synced 2025-03-26 17:51:53 +01:00
feat: introduce NotFoundError
instead of asserting and returning Optional on get operation in crud, we through and NotFoundError. in the context of an API request the FastApi errorhandler will catch it and throw an 404 like we have for Invoice- and PaymentError. In non-api context we can simly try catch the notfound like so: ```python try: url = await get_tinyurl("test") except NotFoundError: # url does not exist do whatever pass ```
This commit is contained in:
parent
8aa1716e32
commit
12b3659f02
@ -10,6 +10,7 @@ from passlib.context import CryptContext
|
||||
from lnbits.core.db import db
|
||||
from lnbits.core.models import PaymentState
|
||||
from lnbits.db import DB_TYPE, SQLITE, Connection, Database, Filters, Page
|
||||
from lnbits.exceptions import NotFoundError
|
||||
from lnbits.extension_manager import (
|
||||
InstallableExtension,
|
||||
PayToEnableInfo,
|
||||
@ -61,10 +62,7 @@ async def create_account(
|
||||
(user_id, username, password, email, extra, now, now),
|
||||
)
|
||||
|
||||
new_account = await get_account(user_id=user_id, conn=conn)
|
||||
assert new_account, "Newly created account couldn't be retrieved"
|
||||
|
||||
return new_account
|
||||
return await get_account(user_id=user_id, conn=conn)
|
||||
|
||||
|
||||
async def update_account(
|
||||
@ -72,9 +70,8 @@ async def update_account(
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
user_config: Optional[UserConfig] = None,
|
||||
) -> Optional[User]:
|
||||
) -> User:
|
||||
user = await get_account(user_id)
|
||||
assert user, "User not found"
|
||||
|
||||
if email:
|
||||
assert not user.email or email == user.email, "Cannot change email."
|
||||
@ -106,9 +103,7 @@ async def update_account(
|
||||
),
|
||||
)
|
||||
|
||||
user = await get_user(user_id)
|
||||
assert user, "Updated account couldn't be retrieved"
|
||||
return user
|
||||
return await get_user(user_id)
|
||||
|
||||
|
||||
async def delete_account(user_id: str, conn: Optional[Connection] = None) -> None:
|
||||
@ -151,9 +146,7 @@ async def get_accounts(
|
||||
)
|
||||
|
||||
|
||||
async def get_account(
|
||||
user_id: str, conn: Optional[Connection] = None
|
||||
) -> Optional[User]:
|
||||
async def get_account(user_id: str, conn: Optional[Connection] = None) -> User:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT id, email, username, created_at, updated_at, extra
|
||||
@ -161,9 +154,10 @@ async def get_account(
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
user = User(**row) if row else None
|
||||
if user and row["extra"]:
|
||||
if not row:
|
||||
raise NotFoundError()
|
||||
user = User(**row)
|
||||
if row["extra"]:
|
||||
user.config = UserConfig(**json.loads(row["extra"]))
|
||||
return user
|
||||
|
||||
@ -211,7 +205,7 @@ async def verify_user_password(user_id: str, password: str) -> bool:
|
||||
|
||||
|
||||
# todo: , conn: Optional[Connection] = None ??
|
||||
async def update_user_password(data: UpdateUserPassword) -> Optional[User]:
|
||||
async def update_user_password(data: UpdateUserPassword) -> User:
|
||||
assert data.password == data.password_repeat, "Passwords do not match."
|
||||
|
||||
# old accounts do not have a pasword
|
||||
@ -235,9 +229,7 @@ async def update_user_password(data: UpdateUserPassword) -> Optional[User]:
|
||||
),
|
||||
)
|
||||
|
||||
user = await get_user(data.user_id)
|
||||
assert user, "Updated account couldn't be retrieved"
|
||||
return user
|
||||
return await get_user(data.user_id)
|
||||
|
||||
|
||||
async def get_account_by_username(
|
||||
@ -277,7 +269,7 @@ async def get_account_by_username_or_email(
|
||||
return user
|
||||
|
||||
|
||||
async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]:
|
||||
async def get_user(user_id: str, conn: Optional[Connection] = None) -> User:
|
||||
user = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT id, email, username, pass, extra, created_at, updated_at
|
||||
@ -285,21 +277,20 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
if not user:
|
||||
raise NotFoundError()
|
||||
|
||||
if user:
|
||||
extensions = await get_user_active_extensions_ids(user_id, conn)
|
||||
wallets = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT *, COALESCE((
|
||||
SELECT balance FROM balances WHERE wallet = wallets.id
|
||||
), 0) AS balance_msat
|
||||
FROM wallets
|
||||
WHERE "user" = ? and wallets.deleted = false
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
extensions = await get_user_active_extensions_ids(user_id, conn)
|
||||
wallets = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT *, COALESCE((
|
||||
SELECT balance FROM balances WHERE wallet = wallets.id
|
||||
), 0) AS balance_msat
|
||||
FROM wallets
|
||||
WHERE "user" = ? and wallets.deleted = false
|
||||
""",
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
return User(
|
||||
id=user["id"],
|
||||
@ -534,10 +525,7 @@ async def create_wallet(
|
||||
),
|
||||
)
|
||||
|
||||
new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
assert new_wallet, "Newly created wallet couldn't be retrieved"
|
||||
|
||||
return new_wallet
|
||||
return await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
|
||||
|
||||
async def update_wallet(
|
||||
@ -564,9 +552,7 @@ async def update_wallet(
|
||||
""",
|
||||
tuple(values),
|
||||
)
|
||||
wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
assert wallet, "updated created wallet couldn't be retrieved"
|
||||
return wallet
|
||||
return await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
|
||||
|
||||
async def delete_wallet(
|
||||
@ -637,9 +623,7 @@ async def delete_unused_wallets(
|
||||
)
|
||||
|
||||
|
||||
async def get_wallet(
|
||||
wallet_id: str, conn: Optional[Connection] = None
|
||||
) -> Optional[Wallet]:
|
||||
async def get_wallet(wallet_id: str, conn: Optional[Connection] = None) -> Wallet:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0)
|
||||
@ -647,8 +631,10 @@ async def get_wallet(
|
||||
""",
|
||||
(wallet_id,),
|
||||
)
|
||||
if not row:
|
||||
raise NotFoundError()
|
||||
|
||||
return Wallet(**row) if row else None
|
||||
return Wallet(**row)
|
||||
|
||||
|
||||
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[Wallet]:
|
||||
@ -722,7 +708,7 @@ async def get_standalone_payment(
|
||||
|
||||
async def get_wallet_payment(
|
||||
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
|
||||
) -> Optional[Payment]:
|
||||
) -> Payment:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT *
|
||||
@ -731,8 +717,10 @@ async def get_wallet_payment(
|
||||
""",
|
||||
(wallet_id, payment_hash),
|
||||
)
|
||||
if not row:
|
||||
raise NotFoundError()
|
||||
|
||||
return Payment.from_row(row) if row else None
|
||||
return Payment.from_row(row)
|
||||
|
||||
|
||||
async def get_latest_payments_by_extension(ext_name: str, ext_id: str, limit: int = 5):
|
||||
@ -927,10 +915,7 @@ async def create_payment(
|
||||
),
|
||||
)
|
||||
|
||||
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
|
||||
assert new_payment, "Newly created payment couldn't be retrieved"
|
||||
|
||||
return new_payment
|
||||
return await get_wallet_payment(wallet_id, payment_hash, conn=conn)
|
||||
|
||||
|
||||
async def update_payment_status(
|
||||
@ -1045,10 +1030,7 @@ async def get_payments_history(
|
||||
)
|
||||
if wallet_id:
|
||||
wallet = await get_wallet(wallet_id)
|
||||
if wallet:
|
||||
balance = wallet.balance_msat
|
||||
else:
|
||||
raise ValueError("Unknown wallet")
|
||||
balance = wallet.balance_msat
|
||||
else:
|
||||
balance = await get_total_balance()
|
||||
|
||||
@ -1210,7 +1192,7 @@ async def delete_dbversion(*, ext_id: str, conn: Optional[Connection] = None) ->
|
||||
# -------
|
||||
|
||||
|
||||
async def create_tinyurl(domain: str, endless: bool, wallet: str):
|
||||
async def create_tinyurl(domain: str, endless: bool, wallet: str) -> TinyURL:
|
||||
tinyurl_id = shortuuid.uuid()[:8]
|
||||
await db.execute(
|
||||
"INSERT INTO tiny_url (id, url, endless, wallet) VALUES (?, ?, ?, ?)",
|
||||
@ -1224,12 +1206,14 @@ async def create_tinyurl(domain: str, endless: bool, wallet: str):
|
||||
return await get_tinyurl(tinyurl_id)
|
||||
|
||||
|
||||
async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]:
|
||||
async def get_tinyurl(tinyurl_id: str) -> TinyURL:
|
||||
row = await db.fetchone(
|
||||
"SELECT * FROM tiny_url WHERE id = ?",
|
||||
(tinyurl_id,),
|
||||
)
|
||||
return TinyURL.from_row(row) if row else None
|
||||
if not row:
|
||||
raise NotFoundError()
|
||||
return TinyURL.from_row(row)
|
||||
|
||||
|
||||
async def get_tinyurl_by_url(url: str) -> List[TinyURL]:
|
||||
@ -1267,9 +1251,7 @@ async def create_webpush_settings(webpush_settings: dict):
|
||||
return await get_webpush_settings()
|
||||
|
||||
|
||||
async def get_webpush_subscription(
|
||||
endpoint: str, user: str
|
||||
) -> Optional[WebPushSubscription]:
|
||||
async def get_webpush_subscription(endpoint: str, user: str) -> WebPushSubscription:
|
||||
row = await db.fetchone(
|
||||
"""SELECT * FROM webpush_subscriptions WHERE endpoint = ? AND "user" = ?""",
|
||||
(
|
||||
@ -1277,12 +1259,12 @@ async def get_webpush_subscription(
|
||||
user,
|
||||
),
|
||||
)
|
||||
return WebPushSubscription(**dict(row)) if row else None
|
||||
if not row:
|
||||
raise NotFoundError()
|
||||
return WebPushSubscription(**row)
|
||||
|
||||
|
||||
async def get_webpush_subscriptions_for_user(
|
||||
user: str,
|
||||
) -> List[WebPushSubscription]:
|
||||
async def get_webpush_subscriptions_for_user(user: str) -> List[WebPushSubscription]:
|
||||
rows = await db.fetchall(
|
||||
"""SELECT * FROM webpush_subscriptions WHERE "user" = ?""",
|
||||
(user,),
|
||||
@ -1305,9 +1287,7 @@ async def create_webpush_subscription(
|
||||
host,
|
||||
),
|
||||
)
|
||||
subscription = await get_webpush_subscription(endpoint, user)
|
||||
assert subscription, "Newly created webpush subscription couldn't be retrieved"
|
||||
return subscription
|
||||
return await get_webpush_subscription(endpoint, user)
|
||||
|
||||
|
||||
async def delete_webpush_subscription(endpoint: str, user: str) -> int:
|
||||
|
@ -32,16 +32,10 @@ async def api_create_tinyurl(
|
||||
url: str, endless: bool = False, wallet: WalletTypeInfo = Depends(require_admin_key)
|
||||
):
|
||||
tinyurls = await get_tinyurl_by_url(url)
|
||||
try:
|
||||
for tinyurl in tinyurls:
|
||||
if tinyurl:
|
||||
if tinyurl.wallet == wallet.wallet.id:
|
||||
return tinyurl
|
||||
return await create_tinyurl(url, endless, wallet.wallet.id)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST, detail="Unable to create tinyurl"
|
||||
) from exc
|
||||
for tinyurl in tinyurls:
|
||||
if tinyurl.wallet == wallet.wallet.id:
|
||||
return tinyurl
|
||||
return await create_tinyurl(url, endless, wallet.wallet.id)
|
||||
|
||||
|
||||
@tinyurl_router.get(
|
||||
@ -52,18 +46,12 @@ async def api_create_tinyurl(
|
||||
async def api_get_tinyurl(
|
||||
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
|
||||
):
|
||||
try:
|
||||
tinyurl = await get_tinyurl(tinyurl_id)
|
||||
if tinyurl:
|
||||
if tinyurl.wallet == wallet.wallet.id:
|
||||
return tinyurl
|
||||
tinyurl = await get_tinyurl(tinyurl_id)
|
||||
if tinyurl.wallet != wallet.wallet.id:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND, detail="Unable to fetch tinyurl"
|
||||
) from exc
|
||||
return tinyurl
|
||||
|
||||
|
||||
@tinyurl_router.delete(
|
||||
@ -74,19 +62,13 @@ async def api_get_tinyurl(
|
||||
async def api_delete_tinyurl(
|
||||
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)
|
||||
):
|
||||
try:
|
||||
tinyurl = await get_tinyurl(tinyurl_id)
|
||||
if tinyurl:
|
||||
if tinyurl.wallet == wallet.wallet.id:
|
||||
await delete_tinyurl(tinyurl_id)
|
||||
return {"deleted": True}
|
||||
tinyurl = await get_tinyurl(tinyurl_id)
|
||||
if tinyurl.wallet != wallet.wallet.id:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST, detail="Unable to delete"
|
||||
) from exc
|
||||
await delete_tinyurl(tinyurl_id)
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@tinyurl_router.get(
|
||||
@ -96,10 +78,5 @@ async def api_delete_tinyurl(
|
||||
)
|
||||
async def api_tinyurl(tinyurl_id: str):
|
||||
tinyurl = await get_tinyurl(tinyurl_id)
|
||||
if tinyurl:
|
||||
response = RedirectResponse(url=tinyurl.url)
|
||||
return response
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND, detail="unable to find tinyurl"
|
||||
)
|
||||
response = RedirectResponse(url=tinyurl.url)
|
||||
return response
|
||||
|
@ -11,24 +11,40 @@ from loguru import logger
|
||||
from .helpers import template_renderer
|
||||
|
||||
|
||||
class PaymentError(Exception):
|
||||
class LnbitsError(Exception):
|
||||
"""Base class for all exceptions in lnbits."""
|
||||
|
||||
|
||||
class PaymentError(LnbitsError):
|
||||
"""raised by fundingsource pay_invoice operations when an error occurs"""
|
||||
|
||||
def __init__(self, message: str, status: str = "pending"):
|
||||
self.message = message
|
||||
self.status = status
|
||||
|
||||
|
||||
class InvoiceError(Exception):
|
||||
class InvoiceError(LnbitsError):
|
||||
"""raised by fundingsource create_invoice operations when an error occurs"""
|
||||
|
||||
def __init__(self, message: str, status: str = "pending"):
|
||||
self.message = message
|
||||
self.status = status
|
||||
|
||||
|
||||
class NotFoundError(LnbitsError):
|
||||
"""
|
||||
Raised by crud operations when a resource is not found.
|
||||
Raises (401) error in api context.
|
||||
"""
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI):
|
||||
register_exception_handler(app)
|
||||
register_request_validation_exception_handler(app)
|
||||
register_http_exception_handler(app)
|
||||
register_payment_error_handler(app)
|
||||
register_invoice_error_handler(app)
|
||||
register_not_found_error_handler(app)
|
||||
|
||||
|
||||
def render_html_error(request: Request, exc: Exception) -> Optional[Response]:
|
||||
@ -115,3 +131,12 @@ def register_invoice_error_handler(app: FastAPI):
|
||||
status_code=520,
|
||||
content={"detail": exc.message, "status": exc.status},
|
||||
)
|
||||
|
||||
|
||||
def register_not_found_error_handler(app: FastAPI):
|
||||
@app.exception_handler(NotFoundError)
|
||||
async def notfound_error_handler(request: Request, exc: NotFoundError):
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
content={"detail": f"{exc!s}"},
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user