fix sqlite database locked issues by using an async lock on the database and requiring explicit transaction control (or each command will be its own transaction).

This commit is contained in:
fiatjaf
2021-03-26 19:10:30 -03:00
parent 9cc7052920
commit 85011d23c3
10 changed files with 276 additions and 276 deletions

View File

@@ -79,10 +79,6 @@ def register_blueprints(app: QuartTrio) -> None:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
bp = getattr(ext_module, f"{ext.code}_ext")
@bp.teardown_request
async def after_request(exc):
await ext_module.db.close_session()
app.register_blueprint(bp, url_prefix=f"/{ext.code}")
except Exception:
raise ImportError(
@@ -122,12 +118,6 @@ def register_request_hooks(app: QuartTrio):
async def before_request():
g.nursery = app.nursery
@app.teardown_request
async def after_request(exc):
from lnbits.core import db
await db.close_session()
@app.after_request
async def set_secure_headers(response):
secure_headers.quart(response)

View File

@@ -53,46 +53,41 @@ def bundle_vendored():
async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
core_conn = await core_db.connect()
core_txn = await core_conn.begin()
try:
rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall()
except OperationalError:
# migration 3 wasn't ran
await core_migrations.m000_create_migrations_table(core_conn)
rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall()
current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
async def run_migration(db, migrations_module):
db_name = migrations_module.__name__.split(".")[-2]
for key, migrate in migrations_module.__dict__.items():
match = match = matcher.match(key)
if match:
version = int(match.group(1))
if version > current_versions.get(db_name, 0):
print(f"running migration {db_name}.{version}")
await migrate(db)
await core_conn.execute(
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)",
(db_name, version),
)
await run_migration(core_conn, core_migrations)
for ext in get_valid_extensions():
async with core_db.connect() as conn:
try:
ext_migrations = importlib.import_module(
f"lnbits.extensions.{ext.code}.migrations"
)
ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
await run_migration(ext_db, ext_migrations)
except ImportError:
raise ImportError(
f"Please make sure that the extension `{ext.code}` has a migrations file."
)
rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
except OperationalError:
# migration 3 wasn't ran
await core_migrations.m000_create_migrations_table(conn)
rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
await core_txn.commit()
await core_conn.close()
current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_")
async def run_migration(db, migrations_module):
db_name = migrations_module.__name__.split(".")[-2]
for key, migrate in migrations_module.__dict__.items():
match = match = matcher.match(key)
if match:
version = int(match.group(1))
if version > current_versions.get(db_name, 0):
print(f"running migration {db_name}.{version}")
await migrate(db)
await conn.execute(
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)",
(db_name, version),
)
await run_migration(conn, core_migrations)
for ext in get_valid_extensions():
try:
ext_migrations = importlib.import_module(
f"lnbits.extensions.{ext.code}.migrations"
)
ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
await run_migration(ext_db, ext_migrations)
except ImportError:
raise ImportError(
f"Please make sure that the extension `{ext.code}` has a migrations file."
)

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
from typing import List, Optional, Dict, Any
from lnbits import bolt11
from lnbits.db import Connection
from lnbits.settings import DEFAULT_WALLET_NAME
from . import db
@@ -14,32 +15,36 @@ from .models import User, Wallet, Payment
# --------
async def create_account() -> User:
async def create_account(conn: Optional[Connection] = None) -> User:
user_id = uuid4().hex
await db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
await (conn or db).execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
new_account = await get_account(user_id=user_id)
new_account = await get_account(user_id=user_id, conn=conn)
assert new_account, "Newly created account couldn't be retrieved"
return new_account
async def get_account(user_id: str) -> Optional[User]:
row = await db.fetchone(
async def get_account(
user_id: str, conn: Optional[Connection] = None
) -> Optional[User]:
row = await (conn or db).fetchone(
"SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)
)
return User(**row) if row else None
async def get_user(user_id: str) -> Optional[User]:
user = await db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]:
user = await (conn or db).fetchone(
"SELECT id, email FROM accounts WHERE id = ?", (user_id,)
)
if user:
extensions = await db.fetchall(
extensions = await (conn or db).fetchall(
"SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)
)
wallets = await db.fetchall(
wallets = await (conn or db).fetchall(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
@@ -63,8 +68,10 @@ async def get_user(user_id: str) -> Optional[User]:
)
async def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
await db.execute(
async def update_user_extension(
*, user_id: str, extension: str, active: int, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"""
INSERT OR REPLACE INTO extensions (user, extension, active)
VALUES (?, ?, ?)
@@ -77,9 +84,14 @@ async def update_user_extension(*, user_id: str, extension: str, active: int) ->
# -------
async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
async def create_wallet(
*,
user_id: str,
wallet_name: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Wallet:
wallet_id = uuid4().hex
await db.execute(
await (conn or db).execute(
"""
INSERT INTO wallets (id, name, user, adminkey, inkey)
VALUES (?, ?, ?, ?, ?)
@@ -93,14 +105,16 @@ async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> W
),
)
new_wallet = await get_wallet(wallet_id=wallet_id)
new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert new_wallet, "Newly created wallet couldn't be retrieved"
return new_wallet
async def delete_wallet(*, user_id: str, wallet_id: str) -> None:
await db.execute(
async def delete_wallet(
*, user_id: str, wallet_id: str, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"""
UPDATE wallets AS w
SET
@@ -113,8 +127,10 @@ async def delete_wallet(*, user_id: str, wallet_id: str) -> None:
)
async def get_wallet(wallet_id: str) -> Optional[Wallet]:
row = await db.fetchone(
async def get_wallet(
wallet_id: str, conn: Optional[Connection] = None
) -> Optional[Wallet]:
row = await (conn or db).fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
@@ -126,8 +142,10 @@ async def get_wallet(wallet_id: str) -> Optional[Wallet]:
return Wallet(**row) if row else None
async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
row = await db.fetchone(
async def get_wallet_for_key(
key: str, key_type: str = "invoice", conn: Optional[Connection] = None
) -> Optional[Wallet]:
row = await (conn or db).fetchone(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets
@@ -149,8 +167,10 @@ async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wa
# ---------------
async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]:
row = await db.fetchone(
async def get_standalone_payment(
checking_id_or_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
row = await (conn or db).fetchone(
"""
SELECT *
FROM apipayments
@@ -163,8 +183,10 @@ async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]:
return Payment.from_row(row) if row else None
async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
row = await db.fetchone(
async def get_wallet_payment(
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
row = await (conn or db).fetchone(
"""
SELECT *
FROM apipayments
@@ -185,6 +207,7 @@ async def get_payments(
incoming: bool = False,
since: Optional[int] = None,
exclude_uncheckable: bool = False,
conn: Optional[Connection] = None,
) -> List[Payment]:
"""
Filters payments to be returned by complete | pending | outgoing | incoming.
@@ -227,7 +250,7 @@ async def get_payments(
if clause:
where = f"WHERE {' AND '.join(clause)}"
rows = await db.fetchall(
rows = await (conn or db).fetchall(
f"""
SELECT *
FROM apipayments
@@ -240,8 +263,10 @@ async def get_payments(
return [Payment.from_row(row) for row in rows]
async def delete_expired_invoices() -> None:
rows = await db.fetchall(
async def delete_expired_invoices(
conn: Optional[Connection] = None,
) -> None:
rows = await (conn or db).fetchall(
"""
SELECT bolt11
FROM apipayments
@@ -258,7 +283,7 @@ async def delete_expired_invoices() -> None:
if expiration_date > datetime.datetime.utcnow():
continue
await db.execute(
await (conn or db).execute(
"""
DELETE FROM apipayments
WHERE pending = 1 AND hash = ?
@@ -284,8 +309,9 @@ async def create_payment(
pending: bool = True,
extra: Optional[Dict] = None,
webhook: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Payment:
await db.execute(
await (conn or db).execute(
"""
INSERT INTO apipayments
(wallet, checking_id, bolt11, hash, preimage,
@@ -309,14 +335,18 @@ async def create_payment(
),
)
new_payment = await get_wallet_payment(wallet_id, payment_hash)
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment
async def update_payment_status(checking_id: str, pending: bool) -> None:
await db.execute(
async def update_payment_status(
checking_id: str,
pending: bool,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"UPDATE apipayments SET pending = ? WHERE checking_id = ?",
(
int(pending),
@@ -325,12 +355,20 @@ async def update_payment_status(checking_id: str, pending: bool) -> None:
)
async def delete_payment(checking_id: str) -> None:
await db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
async def delete_payment(
checking_id: str,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)
)
async def check_internal(payment_hash: str) -> Optional[str]:
row = await db.fetchone(
async def check_internal(
payment_hash: str,
conn: Optional[Connection] = None,
) -> Optional[str]:
row = await (conn or db).fetchone(
"""
SELECT checking_id FROM apipayments
WHERE hash = ? AND pending AND amount > 0

View File

@@ -13,6 +13,7 @@ except ImportError: # pragma: nocover
from typing_extensions import TypedDict
from lnbits import bolt11
from lnbits.db import Connection
from lnbits.helpers import urlsafe_short_hash
from lnbits.settings import WALLET
from lnbits.wallets.base import PaymentStatus, PaymentResponse
@@ -36,8 +37,8 @@ async def create_invoice(
description_hash: Optional[bytes] = None,
extra: Optional[Dict] = None,
webhook: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Tuple[str, str]:
await db.begin()
invoice_memo = None if description_hash else memo
storeable_memo = memo
@@ -59,9 +60,9 @@ async def create_invoice(
memo=storeable_memo,
extra=extra,
webhook=webhook,
conn=conn,
)
await db.commit()
return invoice.payment_hash, payment_request
@@ -72,102 +73,114 @@ async def pay_invoice(
max_sat: Optional[int] = None,
extra: Optional[Dict] = None,
description: str = "",
conn: Optional[Connection] = None,
) -> str:
await db.begin()
temp_id = f"temp_{urlsafe_short_hash()}"
internal_id = f"internal_{urlsafe_short_hash()}"
async with (db.reuse_conn(conn) if conn else db.connect()) as conn:
temp_id = f"temp_{urlsafe_short_hash()}"
internal_id = f"internal_{urlsafe_short_hash()}"
invoice = bolt11.decode(payment_request)
if invoice.amount_msat == 0:
raise ValueError("Amountless invoices not supported.")
if max_sat and invoice.amount_msat > max_sat * 1000:
raise ValueError("Amount in invoice is too high.")
invoice = bolt11.decode(payment_request)
if invoice.amount_msat == 0:
raise ValueError("Amountless invoices not supported.")
if max_sat and invoice.amount_msat > max_sat * 1000:
raise ValueError("Amount in invoice is too high.")
# put all parameters that don't change here
PaymentKwargs = TypedDict(
"PaymentKwargs",
{
"wallet_id": str,
"payment_request": str,
"payment_hash": str,
"amount": int,
"memo": str,
"extra": Optional[Dict],
},
)
payment_kwargs: PaymentKwargs = dict(
wallet_id=wallet_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=-invoice.amount_msat,
memo=description or invoice.description or "",
extra=extra,
)
# check_internal() returns the checking_id of the invoice we're waiting for
internal_checking_id = await check_internal(invoice.payment_hash)
if internal_checking_id:
# create a new payment from this wallet
await create_payment(
checking_id=internal_id, fee=0, pending=False, **payment_kwargs
# put all parameters that don't change here
PaymentKwargs = TypedDict(
"PaymentKwargs",
{
"wallet_id": str,
"payment_request": str,
"payment_hash": str,
"amount": int,
"memo": str,
"extra": Optional[Dict],
},
)
payment_kwargs: PaymentKwargs = dict(
wallet_id=wallet_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=-invoice.amount_msat,
memo=description or invoice.description or "",
extra=extra,
)
else:
# create a temporary payment here so we can check if
# the balance is enough in the next step
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs)
# do the balance check
wallet = await get_wallet(wallet_id)
assert wallet
if wallet.balance_msat < 0:
await db.rollback()
raise PermissionError("Insufficient balance.")
else:
await db.commit()
await db.begin()
if internal_checking_id:
# mark the invoice from the other side as not pending anymore
# so the other side only has access to his new money when we are sure
# the payer has enough to deduct from
await update_payment_status(checking_id=internal_checking_id, pending=False)
# notify receiver asynchronously
from lnbits.tasks import internal_invoice_paid
await internal_invoice_paid.send(internal_checking_id)
else:
# actually pay the external invoice
payment: PaymentResponse = await WALLET.pay_invoice(payment_request)
if payment.checking_id:
# check_internal() returns the checking_id of the invoice we're waiting for
internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
if internal_checking_id:
# create a new payment from this wallet
await create_payment(
checking_id=payment.checking_id,
fee=payment.fee_msat,
preimage=payment.preimage,
pending=payment.ok == None,
checking_id=internal_id,
fee=0,
pending=False,
conn=conn,
**payment_kwargs,
)
await delete_payment(temp_id)
await db.commit()
else:
await delete_payment(temp_id)
await db.commit()
raise Exception(
payment.error_message or "Failed to pay_invoice on backend."
# create a temporary payment here so we can check if
# the balance is enough in the next step
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
await create_payment(
checking_id=temp_id,
fee=-fee_reserve,
conn=conn,
**payment_kwargs,
)
return invoice.payment_hash
# do the balance check
wallet = await get_wallet(wallet_id, conn=conn)
assert wallet
if wallet.balance_msat < 0:
raise PermissionError("Insufficient balance.")
if internal_checking_id:
# mark the invoice from the other side as not pending anymore
# so the other side only has access to his new money when we are sure
# the payer has enough to deduct from
await update_payment_status(
checking_id=internal_checking_id,
pending=False,
conn=conn,
)
# notify receiver asynchronously
from lnbits.tasks import internal_invoice_paid
await internal_invoice_paid.send(internal_checking_id)
else:
# actually pay the external invoice
payment: PaymentResponse = await WALLET.pay_invoice(payment_request)
if payment.checking_id:
await create_payment(
checking_id=payment.checking_id,
fee=payment.fee_msat,
preimage=payment.preimage,
pending=payment.ok == None,
conn=conn,
**payment_kwargs,
)
await delete_payment(temp_id, conn=conn)
else:
raise Exception(
payment.error_message or "Failed to pay_invoice on backend."
)
return invoice.payment_hash
async def redeem_lnurl_withdraw(
wallet_id: str, res: LnurlWithdrawResponse, memo: Optional[str] = None
wallet_id: str,
res: LnurlWithdrawResponse,
memo: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
_, payment_request = await create_invoice(
wallet_id=wallet_id,
amount=res.max_sats,
memo=memo or res.default_description or "",
extra={"tag": "lnurlwallet"},
conn=conn,
)
async with httpx.AsyncClient() as client:
@@ -180,7 +193,10 @@ async def redeem_lnurl_withdraw(
)
async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]:
async def perform_lnurlauth(
callback: str,
conn: Optional[Connection] = None,
) -> Optional[LnurlErrorResponse]:
cb = urlparse(callback)
k1 = unhexlify(parse_qs(cb.query)["k1"][0])
@@ -251,8 +267,12 @@ async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]:
)
async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus:
payment = await get_wallet_payment(wallet_id, payment_hash)
async def check_invoice_status(
wallet_id: str,
payment_hash: str,
conn: Optional[Connection] = None,
) -> PaymentStatus:
payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
if not payment:
return PaymentStatus(None)

View File

@@ -66,7 +66,7 @@ async def api_payments_create_invoice():
description_hash = b""
memo = g.data["memo"]
try:
async with db.connect() as conn:
payment_hash, payment_request = await create_invoice(
wallet_id=g.wallet.id,
amount=g.data["amount"],
@@ -74,12 +74,8 @@ async def api_payments_create_invoice():
description_hash=description_hash,
extra=g.data.get("extra"),
webhook=g.data.get("webhook"),
conn=conn,
)
except Exception as exc:
await db.rollback()
raise exc
await db.commit()
invoice = bolt11.decode(payment_request)
@@ -124,14 +120,14 @@ async def api_payments_create_invoice():
async def api_payments_pay_invoice():
try:
payment_hash = await pay_invoice(
wallet_id=g.wallet.id, payment_request=g.data["bolt11"]
wallet_id=g.wallet.id,
payment_request=g.data["bolt11"],
)
except ValueError as e:
return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST
except PermissionError as e:
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
except Exception as exc:
await db.rollback()
raise exc
return (
@@ -217,23 +213,19 @@ async def api_payments_pay_lnurl():
HTTPStatus.BAD_REQUEST,
)
try:
extra = {}
extra = {}
if params.get("successAction"):
extra["success_action"] = params["successAction"]
if g.data["comment"]:
extra["comment"] = g.data["comment"]
if params.get("successAction"):
extra["success_action"] = params["successAction"]
if g.data["comment"]:
extra["comment"] = g.data["comment"]
payment_hash = await pay_invoice(
wallet_id=g.wallet.id,
payment_request=params["pr"],
description=g.data.get("description", ""),
extra=extra,
)
except Exception as exc:
await db.rollback()
raise exc
payment_hash = await pay_invoice(
wallet_id=g.wallet.id,
payment_request=params["pr"],
description=g.data.get("description", ""),
extra=extra,
)
return (
jsonify(

View File

@@ -13,11 +13,10 @@ from quart import (
)
from lnurl import LnurlResponse, LnurlWithdrawResponse, decode as decode_lnurl # type: ignore
from lnbits.core import core_app
from lnbits.core import core_app, db
from lnbits.decorators import check_user_exists, validate_uuids
from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE
from .. import db
from ..crud import (
create_account,
get_user,
@@ -152,10 +151,10 @@ async def lnurlwallet():
HTTPStatus.INTERNAL_SERVER_ERROR,
)
account = await create_account()
user = await get_user(account.id)
wallet = await create_wallet(user_id=user.id)
await db.commit()
async with db.connect() as conn:
account = await create_account(conn=conn)
user = await get_user(account.id, conn=conn)
wallet = await create_wallet(user_id=user.id, conn=conn)
g.nursery.start_soon(
redeem_lnurl_withdraw,

View File

@@ -1,57 +1,54 @@
import os
from typing import Tuple, Optional, Any
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
import trio
from contextlib import asynccontextmanager
from sqlalchemy import create_engine # type: ignore
from quart import g
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
from sqlalchemy_aio.base import AsyncConnection
from .settings import LNBITS_DATA_FOLDER
class Connection:
def __init__(self, conn: AsyncConnection):
self.conn = conn
async def fetchall(self, query: str, values: tuple = ()) -> list:
result = await self.conn.execute(query, values)
return await result.fetchall()
async def fetchone(self, query: str, values: tuple = ()):
result = await self.conn.execute(query, values)
row = await result.fetchone()
await result.close()
return row
async def execute(self, query: str, values: tuple = ()):
return await self.conn.execute(query, values)
class Database:
def __init__(self, db_name: str):
self.db_name = db_name
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY)
self.lock = trio.StrictFIFOLock()
def connect(self):
return self.engine.connect()
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]:
@asynccontextmanager
async def connect(self):
await self.lock.acquire()
try:
return getattr(g, f"{self.db_name}_conn", None), getattr(
g, f"{self.db_name}_txn", None
)
except RuntimeError:
return None, None
async def begin(self):
conn, _ = self.session_connection()
if conn:
return
conn = await self.engine.connect()
setattr(g, f"{self.db_name}_conn", conn)
txn = await conn.begin()
setattr(g, f"{self.db_name}_txn", txn)
async with self.engine.connect() as conn:
async with conn.begin():
yield Connection(conn)
finally:
self.lock.release()
async def fetchall(self, query: str, values: tuple = ()) -> list:
conn, _ = self.session_connection()
if conn:
result = await conn.execute(query, values)
return await result.fetchall()
async with self.connect() as conn:
result = await conn.execute(query, values)
return await result.fetchall()
async def fetchone(self, query: str, values: tuple = ()):
conn, _ = self.session_connection()
if conn:
result = await conn.execute(query, values)
row = await result.fetchone()
await result.close()
return row
async with self.connect() as conn:
result = await conn.execute(query, values)
row = await result.fetchone()
@@ -59,29 +56,9 @@ class Database:
return row
async def execute(self, query: str, values: tuple = ()):
conn, _ = self.session_connection()
if conn:
return await conn.execute(query, values)
async with self.connect() as conn:
return await conn.execute(query, values)
async def commit(self):
conn, txn = self.session_connection()
if conn and txn:
await txn.commit()
await self.close_session()
async def rollback(self):
conn, txn = self.session_connection()
if conn and txn:
await txn.rollback()
await self.close_session()
async def close_session(self):
conn, txn = self.session_connection()
if conn and txn:
await txn.close()
await conn.close()
delattr(g, f"{self.db_name}_conn")
delattr(g, f"{self.db_name}_txn")
@asynccontextmanager
async def reuse_conn(self, conn: Connection):
yield conn

View File

@@ -1,6 +1,6 @@
import json
import math
from quart import g, jsonify, request
from quart import jsonify, request
from http import HTTPStatus
import traceback
@@ -29,7 +29,6 @@ from .helpers import (
# Handles signed URL from Bleskomat ATMs and "action" callback of auto-generated LNURLs.
@bleskomat_ext.route("/u", methods=["GET"])
async def api_bleskomat_lnurl():
try:
query = request.args.to_dict()
@@ -125,7 +124,7 @@ async def api_bleskomat_lnurl():
except LnurlHttpError as e:
return jsonify({"status": "ERROR", "reason": str(e)}), e.http_status
except Exception as e:
except Exception:
traceback.print_exc()
return (
jsonify({"status": "ERROR", "reason": "Unexpected error"}),

View File

@@ -61,7 +61,7 @@ class BleskomatLnurl(NamedTuple):
raise LnurlValidationError("Multiple payment requests not supported")
try:
invoice = bolt11.decode(pr)
except ValueError as e:
except ValueError:
raise LnurlValidationError(
'Invalid parameter ("pr"): Lightning payment request expected'
)
@@ -79,14 +79,11 @@ class BleskomatLnurl(NamedTuple):
async def execute_action(self, query: Dict[str, str]):
self.validate_action(query)
used = False
if self.initial_uses > 0:
await db.commit()
await db.begin()
used = await self.use()
if not used:
await db.rollback()
raise LnurlValidationError("Maximum number of uses already reached")
try:
async with db.connect() as conn:
if self.initial_uses > 0:
used = await self.use(conn)
if not used:
raise LnurlValidationError("Maximum number of uses already reached")
tag = self.tag
if tag == "withdrawRequest":
payment_hash = await pay_invoice(
@@ -95,16 +92,10 @@ class BleskomatLnurl(NamedTuple):
)
if not payment_hash:
raise LnurlValidationError("Failed to pay invoice")
except Exception as e:
if used:
await db.rollback()
raise e
if used:
await db.commit()
async def use(self) -> bool:
async def use(self, conn) -> bool:
now = int(time.time())
result = await db.execute(
result = await conn.execute(
"""
UPDATE bleskomat_lnurls
SET remaining_uses = remaining_uses - 1, updated_time = ?

View File

@@ -70,11 +70,10 @@ async def api_bleskomat_retrieve(bleskomat_id):
}
)
async def api_bleskomat_create_or_update(bleskomat_id=None):
try:
fiat_currency = g.data["fiat_currency"]
exchange_rate_provider = g.data["exchange_rate_provider"]
rate = await fetch_fiat_exchange_rate(
await fetch_fiat_exchange_rate(
currency=fiat_currency, provider=exchange_rate_provider
)
except Exception as e: