diff --git a/lnbits/app.py b/lnbits/app.py index 40a9723f0..cd700f5c9 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -79,10 +79,6 @@ def register_blueprints(app: QuartTrio) -> None: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") bp = getattr(ext_module, f"{ext.code}_ext") - @bp.teardown_request - async def after_request(exc): - await ext_module.db.close_session() - app.register_blueprint(bp, url_prefix=f"/{ext.code}") except Exception: raise ImportError( @@ -122,12 +118,6 @@ def register_request_hooks(app: QuartTrio): async def before_request(): g.nursery = app.nursery - @app.teardown_request - async def after_request(exc): - from lnbits.core import db - - await db.close_session() - @app.after_request async def set_secure_headers(response): secure_headers.quart(response) diff --git a/lnbits/commands.py b/lnbits/commands.py index b546765b7..2be04d127 100644 --- a/lnbits/commands.py +++ b/lnbits/commands.py @@ -53,46 +53,41 @@ def bundle_vendored(): async def migrate_databases(): """Creates the necessary databases if they don't exist already; or migrates them.""" - core_conn = await core_db.connect() - core_txn = await core_conn.begin() - - try: - rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall() - except OperationalError: - # migration 3 wasn't ran - await core_migrations.m000_create_migrations_table(core_conn) - rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall() - - current_versions = {row["db"]: row["version"] for row in rows} - matcher = re.compile(r"^m(\d\d\d)_") - - async def run_migration(db, migrations_module): - db_name = migrations_module.__name__.split(".")[-2] - for key, migrate in migrations_module.__dict__.items(): - match = match = matcher.match(key) - if match: - version = int(match.group(1)) - if version > current_versions.get(db_name, 0): - print(f"running migration {db_name}.{version}") - await migrate(db) - await core_conn.execute( - "INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", - (db_name, version), - ) - - await run_migration(core_conn, core_migrations) - - for ext in get_valid_extensions(): + async with core_db.connect() as conn: try: - ext_migrations = importlib.import_module( - f"lnbits.extensions.{ext.code}.migrations" - ) - ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db - await run_migration(ext_db, ext_migrations) - except ImportError: - raise ImportError( - f"Please make sure that the extension `{ext.code}` has a migrations file." - ) + rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall() + except OperationalError: + # migration 3 wasn't ran + await core_migrations.m000_create_migrations_table(conn) + rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall() - await core_txn.commit() - await core_conn.close() + current_versions = {row["db"]: row["version"] for row in rows} + matcher = re.compile(r"^m(\d\d\d)_") + + async def run_migration(db, migrations_module): + db_name = migrations_module.__name__.split(".")[-2] + for key, migrate in migrations_module.__dict__.items(): + match = match = matcher.match(key) + if match: + version = int(match.group(1)) + if version > current_versions.get(db_name, 0): + print(f"running migration {db_name}.{version}") + await migrate(db) + await conn.execute( + "INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", + (db_name, version), + ) + + await run_migration(conn, core_migrations) + + for ext in get_valid_extensions(): + try: + ext_migrations = importlib.import_module( + f"lnbits.extensions.{ext.code}.migrations" + ) + ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db + await run_migration(ext_db, ext_migrations) + except ImportError: + raise ImportError( + f"Please make sure that the extension `{ext.code}` has a migrations file." + ) diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index b37bf375d..87d4972ae 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -4,6 +4,7 @@ from uuid import uuid4 from typing import List, Optional, Dict, Any from lnbits import bolt11 +from lnbits.db import Connection from lnbits.settings import DEFAULT_WALLET_NAME from . import db @@ -14,32 +15,36 @@ from .models import User, Wallet, Payment # -------- -async def create_account() -> User: +async def create_account(conn: Optional[Connection] = None) -> User: user_id = uuid4().hex - await db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) + await (conn or db).execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) - new_account = await get_account(user_id=user_id) + new_account = await get_account(user_id=user_id, conn=conn) assert new_account, "Newly created account couldn't be retrieved" return new_account -async def get_account(user_id: str) -> Optional[User]: - row = await db.fetchone( +async def get_account( + user_id: str, conn: Optional[Connection] = None +) -> Optional[User]: + row = await (conn or db).fetchone( "SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,) ) return User(**row) if row else None -async def get_user(user_id: str) -> Optional[User]: - user = await db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) +async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]: + user = await (conn or db).fetchone( + "SELECT id, email FROM accounts WHERE id = ?", (user_id,) + ) if user: - extensions = await db.fetchall( + extensions = await (conn or db).fetchall( "SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,) ) - wallets = await db.fetchall( + wallets = await (conn or db).fetchall( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -63,8 +68,10 @@ async def get_user(user_id: str) -> Optional[User]: ) -async def update_user_extension(*, user_id: str, extension: str, active: int) -> None: - await db.execute( +async def update_user_extension( + *, user_id: str, extension: str, active: int, conn: Optional[Connection] = None +) -> None: + await (conn or db).execute( """ INSERT OR REPLACE INTO extensions (user, extension, active) VALUES (?, ?, ?) @@ -77,9 +84,14 @@ async def update_user_extension(*, user_id: str, extension: str, active: int) -> # ------- -async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: +async def create_wallet( + *, + user_id: str, + wallet_name: Optional[str] = None, + conn: Optional[Connection] = None, +) -> Wallet: wallet_id = uuid4().hex - await db.execute( + await (conn or db).execute( """ INSERT INTO wallets (id, name, user, adminkey, inkey) VALUES (?, ?, ?, ?, ?) @@ -93,14 +105,16 @@ async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> W ), ) - new_wallet = await get_wallet(wallet_id=wallet_id) + new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn) assert new_wallet, "Newly created wallet couldn't be retrieved" return new_wallet -async def delete_wallet(*, user_id: str, wallet_id: str) -> None: - await db.execute( +async def delete_wallet( + *, user_id: str, wallet_id: str, conn: Optional[Connection] = None +) -> None: + await (conn or db).execute( """ UPDATE wallets AS w SET @@ -113,8 +127,10 @@ async def delete_wallet(*, user_id: str, wallet_id: str) -> None: ) -async def get_wallet(wallet_id: str) -> Optional[Wallet]: - row = await db.fetchone( +async def get_wallet( + wallet_id: str, conn: Optional[Connection] = None +) -> Optional[Wallet]: + row = await (conn or db).fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -126,8 +142,10 @@ async def get_wallet(wallet_id: str) -> Optional[Wallet]: return Wallet(**row) if row else None -async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: - row = await db.fetchone( +async def get_wallet_for_key( + key: str, key_type: str = "invoice", conn: Optional[Connection] = None +) -> Optional[Wallet]: + row = await (conn or db).fetchone( """ SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat FROM wallets @@ -149,8 +167,10 @@ async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wa # --------------- -async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]: - row = await db.fetchone( +async def get_standalone_payment( + checking_id_or_hash: str, conn: Optional[Connection] = None +) -> Optional[Payment]: + row = await (conn or db).fetchone( """ SELECT * FROM apipayments @@ -163,8 +183,10 @@ async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]: return Payment.from_row(row) if row else None -async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: - row = await db.fetchone( +async def get_wallet_payment( + wallet_id: str, payment_hash: str, conn: Optional[Connection] = None +) -> Optional[Payment]: + row = await (conn or db).fetchone( """ SELECT * FROM apipayments @@ -185,6 +207,7 @@ async def get_payments( incoming: bool = False, since: Optional[int] = None, exclude_uncheckable: bool = False, + conn: Optional[Connection] = None, ) -> List[Payment]: """ Filters payments to be returned by complete | pending | outgoing | incoming. @@ -227,7 +250,7 @@ async def get_payments( if clause: where = f"WHERE {' AND '.join(clause)}" - rows = await db.fetchall( + rows = await (conn or db).fetchall( f""" SELECT * FROM apipayments @@ -240,8 +263,10 @@ async def get_payments( return [Payment.from_row(row) for row in rows] -async def delete_expired_invoices() -> None: - rows = await db.fetchall( +async def delete_expired_invoices( + conn: Optional[Connection] = None, +) -> None: + rows = await (conn or db).fetchall( """ SELECT bolt11 FROM apipayments @@ -258,7 +283,7 @@ async def delete_expired_invoices() -> None: if expiration_date > datetime.datetime.utcnow(): continue - await db.execute( + await (conn or db).execute( """ DELETE FROM apipayments WHERE pending = 1 AND hash = ? @@ -284,8 +309,9 @@ async def create_payment( pending: bool = True, extra: Optional[Dict] = None, webhook: Optional[str] = None, + conn: Optional[Connection] = None, ) -> Payment: - await db.execute( + await (conn or db).execute( """ INSERT INTO apipayments (wallet, checking_id, bolt11, hash, preimage, @@ -309,14 +335,18 @@ async def create_payment( ), ) - new_payment = await get_wallet_payment(wallet_id, payment_hash) + new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn) assert new_payment, "Newly created payment couldn't be retrieved" return new_payment -async def update_payment_status(checking_id: str, pending: bool) -> None: - await db.execute( +async def update_payment_status( + checking_id: str, + pending: bool, + conn: Optional[Connection] = None, +) -> None: + await (conn or db).execute( "UPDATE apipayments SET pending = ? WHERE checking_id = ?", ( int(pending), @@ -325,12 +355,20 @@ async def update_payment_status(checking_id: str, pending: bool) -> None: ) -async def delete_payment(checking_id: str) -> None: - await db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) +async def delete_payment( + checking_id: str, + conn: Optional[Connection] = None, +) -> None: + await (conn or db).execute( + "DELETE FROM apipayments WHERE checking_id = ?", (checking_id,) + ) -async def check_internal(payment_hash: str) -> Optional[str]: - row = await db.fetchone( +async def check_internal( + payment_hash: str, + conn: Optional[Connection] = None, +) -> Optional[str]: + row = await (conn or db).fetchone( """ SELECT checking_id FROM apipayments WHERE hash = ? AND pending AND amount > 0 diff --git a/lnbits/core/services.py b/lnbits/core/services.py index b5c3d9fa2..d623b1183 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -13,6 +13,7 @@ except ImportError: # pragma: nocover from typing_extensions import TypedDict from lnbits import bolt11 +from lnbits.db import Connection from lnbits.helpers import urlsafe_short_hash from lnbits.settings import WALLET from lnbits.wallets.base import PaymentStatus, PaymentResponse @@ -36,8 +37,8 @@ async def create_invoice( description_hash: Optional[bytes] = None, extra: Optional[Dict] = None, webhook: Optional[str] = None, + conn: Optional[Connection] = None, ) -> Tuple[str, str]: - await db.begin() invoice_memo = None if description_hash else memo storeable_memo = memo @@ -59,9 +60,9 @@ async def create_invoice( memo=storeable_memo, extra=extra, webhook=webhook, + conn=conn, ) - await db.commit() return invoice.payment_hash, payment_request @@ -72,102 +73,114 @@ async def pay_invoice( max_sat: Optional[int] = None, extra: Optional[Dict] = None, description: str = "", + conn: Optional[Connection] = None, ) -> str: - await db.begin() - temp_id = f"temp_{urlsafe_short_hash()}" - internal_id = f"internal_{urlsafe_short_hash()}" + async with (db.reuse_conn(conn) if conn else db.connect()) as conn: + temp_id = f"temp_{urlsafe_short_hash()}" + internal_id = f"internal_{urlsafe_short_hash()}" - invoice = bolt11.decode(payment_request) - if invoice.amount_msat == 0: - raise ValueError("Amountless invoices not supported.") - if max_sat and invoice.amount_msat > max_sat * 1000: - raise ValueError("Amount in invoice is too high.") + invoice = bolt11.decode(payment_request) + if invoice.amount_msat == 0: + raise ValueError("Amountless invoices not supported.") + if max_sat and invoice.amount_msat > max_sat * 1000: + raise ValueError("Amount in invoice is too high.") - # put all parameters that don't change here - PaymentKwargs = TypedDict( - "PaymentKwargs", - { - "wallet_id": str, - "payment_request": str, - "payment_hash": str, - "amount": int, - "memo": str, - "extra": Optional[Dict], - }, - ) - payment_kwargs: PaymentKwargs = dict( - wallet_id=wallet_id, - payment_request=payment_request, - payment_hash=invoice.payment_hash, - amount=-invoice.amount_msat, - memo=description or invoice.description or "", - extra=extra, - ) - - # check_internal() returns the checking_id of the invoice we're waiting for - internal_checking_id = await check_internal(invoice.payment_hash) - if internal_checking_id: - # create a new payment from this wallet - await create_payment( - checking_id=internal_id, fee=0, pending=False, **payment_kwargs + # put all parameters that don't change here + PaymentKwargs = TypedDict( + "PaymentKwargs", + { + "wallet_id": str, + "payment_request": str, + "payment_hash": str, + "amount": int, + "memo": str, + "extra": Optional[Dict], + }, + ) + payment_kwargs: PaymentKwargs = dict( + wallet_id=wallet_id, + payment_request=payment_request, + payment_hash=invoice.payment_hash, + amount=-invoice.amount_msat, + memo=description or invoice.description or "", + extra=extra, ) - else: - # create a temporary payment here so we can check if - # the balance is enough in the next step - fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) - await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) - # do the balance check - wallet = await get_wallet(wallet_id) - assert wallet - if wallet.balance_msat < 0: - await db.rollback() - raise PermissionError("Insufficient balance.") - else: - await db.commit() - await db.begin() - - if internal_checking_id: - # mark the invoice from the other side as not pending anymore - # so the other side only has access to his new money when we are sure - # the payer has enough to deduct from - await update_payment_status(checking_id=internal_checking_id, pending=False) - - # notify receiver asynchronously - from lnbits.tasks import internal_invoice_paid - - await internal_invoice_paid.send(internal_checking_id) - else: - # actually pay the external invoice - payment: PaymentResponse = await WALLET.pay_invoice(payment_request) - if payment.checking_id: + # check_internal() returns the checking_id of the invoice we're waiting for + internal_checking_id = await check_internal(invoice.payment_hash, conn=conn) + if internal_checking_id: + # create a new payment from this wallet await create_payment( - checking_id=payment.checking_id, - fee=payment.fee_msat, - preimage=payment.preimage, - pending=payment.ok == None, + checking_id=internal_id, + fee=0, + pending=False, + conn=conn, **payment_kwargs, ) - await delete_payment(temp_id) - await db.commit() else: - await delete_payment(temp_id) - await db.commit() - raise Exception( - payment.error_message or "Failed to pay_invoice on backend." + # create a temporary payment here so we can check if + # the balance is enough in the next step + fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) + await create_payment( + checking_id=temp_id, + fee=-fee_reserve, + conn=conn, + **payment_kwargs, ) - return invoice.payment_hash + # do the balance check + wallet = await get_wallet(wallet_id, conn=conn) + assert wallet + if wallet.balance_msat < 0: + raise PermissionError("Insufficient balance.") + + if internal_checking_id: + # mark the invoice from the other side as not pending anymore + # so the other side only has access to his new money when we are sure + # the payer has enough to deduct from + await update_payment_status( + checking_id=internal_checking_id, + pending=False, + conn=conn, + ) + + # notify receiver asynchronously + from lnbits.tasks import internal_invoice_paid + + await internal_invoice_paid.send(internal_checking_id) + else: + # actually pay the external invoice + payment: PaymentResponse = await WALLET.pay_invoice(payment_request) + if payment.checking_id: + await create_payment( + checking_id=payment.checking_id, + fee=payment.fee_msat, + preimage=payment.preimage, + pending=payment.ok == None, + conn=conn, + **payment_kwargs, + ) + await delete_payment(temp_id, conn=conn) + else: + raise Exception( + payment.error_message or "Failed to pay_invoice on backend." + ) + + return invoice.payment_hash async def redeem_lnurl_withdraw( - wallet_id: str, res: LnurlWithdrawResponse, memo: Optional[str] = None + wallet_id: str, + res: LnurlWithdrawResponse, + memo: Optional[str] = None, + conn: Optional[Connection] = None, ) -> None: _, payment_request = await create_invoice( wallet_id=wallet_id, amount=res.max_sats, memo=memo or res.default_description or "", extra={"tag": "lnurlwallet"}, + conn=conn, ) async with httpx.AsyncClient() as client: @@ -180,7 +193,10 @@ async def redeem_lnurl_withdraw( ) -async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]: +async def perform_lnurlauth( + callback: str, + conn: Optional[Connection] = None, +) -> Optional[LnurlErrorResponse]: cb = urlparse(callback) k1 = unhexlify(parse_qs(cb.query)["k1"][0]) @@ -251,8 +267,12 @@ async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]: ) -async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: - payment = await get_wallet_payment(wallet_id, payment_hash) +async def check_invoice_status( + wallet_id: str, + payment_hash: str, + conn: Optional[Connection] = None, +) -> PaymentStatus: + payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn) if not payment: return PaymentStatus(None) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 7290aba24..2d1b99a9b 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -66,7 +66,7 @@ async def api_payments_create_invoice(): description_hash = b"" memo = g.data["memo"] - try: + async with db.connect() as conn: payment_hash, payment_request = await create_invoice( wallet_id=g.wallet.id, amount=g.data["amount"], @@ -74,12 +74,8 @@ async def api_payments_create_invoice(): description_hash=description_hash, extra=g.data.get("extra"), webhook=g.data.get("webhook"), + conn=conn, ) - except Exception as exc: - await db.rollback() - raise exc - - await db.commit() invoice = bolt11.decode(payment_request) @@ -124,14 +120,14 @@ async def api_payments_create_invoice(): async def api_payments_pay_invoice(): try: payment_hash = await pay_invoice( - wallet_id=g.wallet.id, payment_request=g.data["bolt11"] + wallet_id=g.wallet.id, + payment_request=g.data["bolt11"], ) except ValueError as e: return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST except PermissionError as e: return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN except Exception as exc: - await db.rollback() raise exc return ( @@ -217,23 +213,19 @@ async def api_payments_pay_lnurl(): HTTPStatus.BAD_REQUEST, ) - try: - extra = {} + extra = {} - if params.get("successAction"): - extra["success_action"] = params["successAction"] - if g.data["comment"]: - extra["comment"] = g.data["comment"] + if params.get("successAction"): + extra["success_action"] = params["successAction"] + if g.data["comment"]: + extra["comment"] = g.data["comment"] - payment_hash = await pay_invoice( - wallet_id=g.wallet.id, - payment_request=params["pr"], - description=g.data.get("description", ""), - extra=extra, - ) - except Exception as exc: - await db.rollback() - raise exc + payment_hash = await pay_invoice( + wallet_id=g.wallet.id, + payment_request=params["pr"], + description=g.data.get("description", ""), + extra=extra, + ) return ( jsonify( diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 545b26775..eb719d7e5 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -13,11 +13,10 @@ from quart import ( ) from lnurl import LnurlResponse, LnurlWithdrawResponse, decode as decode_lnurl # type: ignore -from lnbits.core import core_app +from lnbits.core import core_app, db from lnbits.decorators import check_user_exists, validate_uuids from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE -from .. import db from ..crud import ( create_account, get_user, @@ -152,10 +151,10 @@ async def lnurlwallet(): HTTPStatus.INTERNAL_SERVER_ERROR, ) - account = await create_account() - user = await get_user(account.id) - wallet = await create_wallet(user_id=user.id) - await db.commit() + async with db.connect() as conn: + account = await create_account(conn=conn) + user = await get_user(account.id, conn=conn) + wallet = await create_wallet(user_id=user.id, conn=conn) g.nursery.start_soon( redeem_lnurl_withdraw, diff --git a/lnbits/db.py b/lnbits/db.py index 6a44b6c25..82bda2141 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -1,57 +1,54 @@ import os -from typing import Tuple, Optional, Any -from sqlalchemy_aio import TRIO_STRATEGY # type: ignore +import trio +from contextlib import asynccontextmanager from sqlalchemy import create_engine # type: ignore -from quart import g +from sqlalchemy_aio import TRIO_STRATEGY # type: ignore +from sqlalchemy_aio.base import AsyncConnection from .settings import LNBITS_DATA_FOLDER +class Connection: + def __init__(self, conn: AsyncConnection): + self.conn = conn + + async def fetchall(self, query: str, values: tuple = ()) -> list: + result = await self.conn.execute(query, values) + return await result.fetchall() + + async def fetchone(self, query: str, values: tuple = ()): + result = await self.conn.execute(query, values) + row = await result.fetchone() + await result.close() + return row + + async def execute(self, query: str, values: tuple = ()): + return await self.conn.execute(query, values) + + class Database: def __init__(self, db_name: str): self.db_name = db_name db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) + self.lock = trio.StrictFIFOLock() - def connect(self): - return self.engine.connect() - - def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]: + @asynccontextmanager + async def connect(self): + await self.lock.acquire() try: - return getattr(g, f"{self.db_name}_conn", None), getattr( - g, f"{self.db_name}_txn", None - ) - except RuntimeError: - return None, None - - async def begin(self): - conn, _ = self.session_connection() - if conn: - return - - conn = await self.engine.connect() - setattr(g, f"{self.db_name}_conn", conn) - txn = await conn.begin() - setattr(g, f"{self.db_name}_txn", txn) + async with self.engine.connect() as conn: + async with conn.begin(): + yield Connection(conn) + finally: + self.lock.release() async def fetchall(self, query: str, values: tuple = ()) -> list: - conn, _ = self.session_connection() - if conn: - result = await conn.execute(query, values) - return await result.fetchall() - async with self.connect() as conn: result = await conn.execute(query, values) return await result.fetchall() async def fetchone(self, query: str, values: tuple = ()): - conn, _ = self.session_connection() - if conn: - result = await conn.execute(query, values) - row = await result.fetchone() - await result.close() - return row - async with self.connect() as conn: result = await conn.execute(query, values) row = await result.fetchone() @@ -59,29 +56,9 @@ class Database: return row async def execute(self, query: str, values: tuple = ()): - conn, _ = self.session_connection() - if conn: - return await conn.execute(query, values) - async with self.connect() as conn: return await conn.execute(query, values) - async def commit(self): - conn, txn = self.session_connection() - if conn and txn: - await txn.commit() - await self.close_session() - - async def rollback(self): - conn, txn = self.session_connection() - if conn and txn: - await txn.rollback() - await self.close_session() - - async def close_session(self): - conn, txn = self.session_connection() - if conn and txn: - await txn.close() - await conn.close() - delattr(g, f"{self.db_name}_conn") - delattr(g, f"{self.db_name}_txn") + @asynccontextmanager + async def reuse_conn(self, conn: Connection): + yield conn diff --git a/lnbits/extensions/bleskomat/lnurl_api.py b/lnbits/extensions/bleskomat/lnurl_api.py index e2743292a..086562d1c 100644 --- a/lnbits/extensions/bleskomat/lnurl_api.py +++ b/lnbits/extensions/bleskomat/lnurl_api.py @@ -1,6 +1,6 @@ import json import math -from quart import g, jsonify, request +from quart import jsonify, request from http import HTTPStatus import traceback @@ -29,7 +29,6 @@ from .helpers import ( # Handles signed URL from Bleskomat ATMs and "action" callback of auto-generated LNURLs. @bleskomat_ext.route("/u", methods=["GET"]) async def api_bleskomat_lnurl(): - try: query = request.args.to_dict() @@ -125,7 +124,7 @@ async def api_bleskomat_lnurl(): except LnurlHttpError as e: return jsonify({"status": "ERROR", "reason": str(e)}), e.http_status - except Exception as e: + except Exception: traceback.print_exc() return ( jsonify({"status": "ERROR", "reason": "Unexpected error"}), diff --git a/lnbits/extensions/bleskomat/models.py b/lnbits/extensions/bleskomat/models.py index 0ec2673bc..d014f25ad 100644 --- a/lnbits/extensions/bleskomat/models.py +++ b/lnbits/extensions/bleskomat/models.py @@ -61,7 +61,7 @@ class BleskomatLnurl(NamedTuple): raise LnurlValidationError("Multiple payment requests not supported") try: invoice = bolt11.decode(pr) - except ValueError as e: + except ValueError: raise LnurlValidationError( 'Invalid parameter ("pr"): Lightning payment request expected' ) @@ -79,14 +79,11 @@ class BleskomatLnurl(NamedTuple): async def execute_action(self, query: Dict[str, str]): self.validate_action(query) used = False - if self.initial_uses > 0: - await db.commit() - await db.begin() - used = await self.use() - if not used: - await db.rollback() - raise LnurlValidationError("Maximum number of uses already reached") - try: + async with db.connect() as conn: + if self.initial_uses > 0: + used = await self.use(conn) + if not used: + raise LnurlValidationError("Maximum number of uses already reached") tag = self.tag if tag == "withdrawRequest": payment_hash = await pay_invoice( @@ -95,16 +92,10 @@ class BleskomatLnurl(NamedTuple): ) if not payment_hash: raise LnurlValidationError("Failed to pay invoice") - except Exception as e: - if used: - await db.rollback() - raise e - if used: - await db.commit() - async def use(self) -> bool: + async def use(self, conn) -> bool: now = int(time.time()) - result = await db.execute( + result = await conn.execute( """ UPDATE bleskomat_lnurls SET remaining_uses = remaining_uses - 1, updated_time = ? diff --git a/lnbits/extensions/bleskomat/views_api.py b/lnbits/extensions/bleskomat/views_api.py index 7e9a1f1a7..2971b0669 100644 --- a/lnbits/extensions/bleskomat/views_api.py +++ b/lnbits/extensions/bleskomat/views_api.py @@ -70,11 +70,10 @@ async def api_bleskomat_retrieve(bleskomat_id): } ) async def api_bleskomat_create_or_update(bleskomat_id=None): - try: fiat_currency = g.data["fiat_currency"] exchange_rate_provider = g.data["exchange_rate_provider"] - rate = await fetch_fiat_exchange_rate( + await fetch_fiat_exchange_rate( currency=fiat_currency, provider=exchange_rate_provider ) except Exception as e: