fix sqlite database locked issues by using an async lock on the database and requiring explicit transaction control (or each command will be its own transaction).

This commit is contained in:
fiatjaf
2021-03-26 19:10:30 -03:00
parent 9cc7052920
commit 85011d23c3
10 changed files with 276 additions and 276 deletions

View File

@@ -79,10 +79,6 @@ def register_blueprints(app: QuartTrio) -> None:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
bp = getattr(ext_module, f"{ext.code}_ext") bp = getattr(ext_module, f"{ext.code}_ext")
@bp.teardown_request
async def after_request(exc):
await ext_module.db.close_session()
app.register_blueprint(bp, url_prefix=f"/{ext.code}") app.register_blueprint(bp, url_prefix=f"/{ext.code}")
except Exception: except Exception:
raise ImportError( raise ImportError(
@@ -122,12 +118,6 @@ def register_request_hooks(app: QuartTrio):
async def before_request(): async def before_request():
g.nursery = app.nursery g.nursery = app.nursery
@app.teardown_request
async def after_request(exc):
from lnbits.core import db
await db.close_session()
@app.after_request @app.after_request
async def set_secure_headers(response): async def set_secure_headers(response):
secure_headers.quart(response) secure_headers.quart(response)

View File

@@ -53,15 +53,13 @@ def bundle_vendored():
async def migrate_databases(): async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them.""" """Creates the necessary databases if they don't exist already; or migrates them."""
core_conn = await core_db.connect() async with core_db.connect() as conn:
core_txn = await core_conn.begin()
try: try:
rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall() rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
except OperationalError: except OperationalError:
# migration 3 wasn't ran # migration 3 wasn't ran
await core_migrations.m000_create_migrations_table(core_conn) await core_migrations.m000_create_migrations_table(conn)
rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall() rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
current_versions = {row["db"]: row["version"] for row in rows} current_versions = {row["db"]: row["version"] for row in rows}
matcher = re.compile(r"^m(\d\d\d)_") matcher = re.compile(r"^m(\d\d\d)_")
@@ -75,12 +73,12 @@ async def migrate_databases():
if version > current_versions.get(db_name, 0): if version > current_versions.get(db_name, 0):
print(f"running migration {db_name}.{version}") print(f"running migration {db_name}.{version}")
await migrate(db) await migrate(db)
await core_conn.execute( await conn.execute(
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)", "INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)",
(db_name, version), (db_name, version),
) )
await run_migration(core_conn, core_migrations) await run_migration(conn, core_migrations)
for ext in get_valid_extensions(): for ext in get_valid_extensions():
try: try:
@@ -93,6 +91,3 @@ async def migrate_databases():
raise ImportError( raise ImportError(
f"Please make sure that the extension `{ext.code}` has a migrations file." f"Please make sure that the extension `{ext.code}` has a migrations file."
) )
await core_txn.commit()
await core_conn.close()

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from lnbits import bolt11 from lnbits import bolt11
from lnbits.db import Connection
from lnbits.settings import DEFAULT_WALLET_NAME from lnbits.settings import DEFAULT_WALLET_NAME
from . import db from . import db
@@ -14,32 +15,36 @@ from .models import User, Wallet, Payment
# -------- # --------
async def create_account() -> User: async def create_account(conn: Optional[Connection] = None) -> User:
user_id = uuid4().hex user_id = uuid4().hex
await db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,)) await (conn or db).execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
new_account = await get_account(user_id=user_id) new_account = await get_account(user_id=user_id, conn=conn)
assert new_account, "Newly created account couldn't be retrieved" assert new_account, "Newly created account couldn't be retrieved"
return new_account return new_account
async def get_account(user_id: str) -> Optional[User]: async def get_account(
row = await db.fetchone( user_id: str, conn: Optional[Connection] = None
) -> Optional[User]:
row = await (conn or db).fetchone(
"SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,) "SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)
) )
return User(**row) if row else None return User(**row) if row else None
async def get_user(user_id: str) -> Optional[User]: async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]:
user = await db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,)) user = await (conn or db).fetchone(
"SELECT id, email FROM accounts WHERE id = ?", (user_id,)
)
if user: if user:
extensions = await db.fetchall( extensions = await (conn or db).fetchall(
"SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,) "SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)
) )
wallets = await db.fetchall( wallets = await (conn or db).fetchall(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets FROM wallets
@@ -63,8 +68,10 @@ async def get_user(user_id: str) -> Optional[User]:
) )
async def update_user_extension(*, user_id: str, extension: str, active: int) -> None: async def update_user_extension(
await db.execute( *, user_id: str, extension: str, active: int, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
""" """
INSERT OR REPLACE INTO extensions (user, extension, active) INSERT OR REPLACE INTO extensions (user, extension, active)
VALUES (?, ?, ?) VALUES (?, ?, ?)
@@ -77,9 +84,14 @@ async def update_user_extension(*, user_id: str, extension: str, active: int) ->
# ------- # -------
async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet: async def create_wallet(
*,
user_id: str,
wallet_name: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Wallet:
wallet_id = uuid4().hex wallet_id = uuid4().hex
await db.execute( await (conn or db).execute(
""" """
INSERT INTO wallets (id, name, user, adminkey, inkey) INSERT INTO wallets (id, name, user, adminkey, inkey)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
@@ -93,14 +105,16 @@ async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> W
), ),
) )
new_wallet = await get_wallet(wallet_id=wallet_id) new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
assert new_wallet, "Newly created wallet couldn't be retrieved" assert new_wallet, "Newly created wallet couldn't be retrieved"
return new_wallet return new_wallet
async def delete_wallet(*, user_id: str, wallet_id: str) -> None: async def delete_wallet(
await db.execute( *, user_id: str, wallet_id: str, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
""" """
UPDATE wallets AS w UPDATE wallets AS w
SET SET
@@ -113,8 +127,10 @@ async def delete_wallet(*, user_id: str, wallet_id: str) -> None:
) )
async def get_wallet(wallet_id: str) -> Optional[Wallet]: async def get_wallet(
row = await db.fetchone( wallet_id: str, conn: Optional[Connection] = None
) -> Optional[Wallet]:
row = await (conn or db).fetchone(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets FROM wallets
@@ -126,8 +142,10 @@ async def get_wallet(wallet_id: str) -> Optional[Wallet]:
return Wallet(**row) if row else None return Wallet(**row) if row else None
async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]: async def get_wallet_for_key(
row = await db.fetchone( key: str, key_type: str = "invoice", conn: Optional[Connection] = None
) -> Optional[Wallet]:
row = await (conn or db).fetchone(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
FROM wallets FROM wallets
@@ -149,8 +167,10 @@ async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wa
# --------------- # ---------------
async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]: async def get_standalone_payment(
row = await db.fetchone( checking_id_or_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
row = await (conn or db).fetchone(
""" """
SELECT * SELECT *
FROM apipayments FROM apipayments
@@ -163,8 +183,10 @@ async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]:
return Payment.from_row(row) if row else None return Payment.from_row(row) if row else None
async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]: async def get_wallet_payment(
row = await db.fetchone( wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
row = await (conn or db).fetchone(
""" """
SELECT * SELECT *
FROM apipayments FROM apipayments
@@ -185,6 +207,7 @@ async def get_payments(
incoming: bool = False, incoming: bool = False,
since: Optional[int] = None, since: Optional[int] = None,
exclude_uncheckable: bool = False, exclude_uncheckable: bool = False,
conn: Optional[Connection] = None,
) -> List[Payment]: ) -> List[Payment]:
""" """
Filters payments to be returned by complete | pending | outgoing | incoming. Filters payments to be returned by complete | pending | outgoing | incoming.
@@ -227,7 +250,7 @@ async def get_payments(
if clause: if clause:
where = f"WHERE {' AND '.join(clause)}" where = f"WHERE {' AND '.join(clause)}"
rows = await db.fetchall( rows = await (conn or db).fetchall(
f""" f"""
SELECT * SELECT *
FROM apipayments FROM apipayments
@@ -240,8 +263,10 @@ async def get_payments(
return [Payment.from_row(row) for row in rows] return [Payment.from_row(row) for row in rows]
async def delete_expired_invoices() -> None: async def delete_expired_invoices(
rows = await db.fetchall( conn: Optional[Connection] = None,
) -> None:
rows = await (conn or db).fetchall(
""" """
SELECT bolt11 SELECT bolt11
FROM apipayments FROM apipayments
@@ -258,7 +283,7 @@ async def delete_expired_invoices() -> None:
if expiration_date > datetime.datetime.utcnow(): if expiration_date > datetime.datetime.utcnow():
continue continue
await db.execute( await (conn or db).execute(
""" """
DELETE FROM apipayments DELETE FROM apipayments
WHERE pending = 1 AND hash = ? WHERE pending = 1 AND hash = ?
@@ -284,8 +309,9 @@ async def create_payment(
pending: bool = True, pending: bool = True,
extra: Optional[Dict] = None, extra: Optional[Dict] = None,
webhook: Optional[str] = None, webhook: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Payment: ) -> Payment:
await db.execute( await (conn or db).execute(
""" """
INSERT INTO apipayments INSERT INTO apipayments
(wallet, checking_id, bolt11, hash, preimage, (wallet, checking_id, bolt11, hash, preimage,
@@ -309,14 +335,18 @@ async def create_payment(
), ),
) )
new_payment = await get_wallet_payment(wallet_id, payment_hash) new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
assert new_payment, "Newly created payment couldn't be retrieved" assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment return new_payment
async def update_payment_status(checking_id: str, pending: bool) -> None: async def update_payment_status(
await db.execute( checking_id: str,
pending: bool,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"UPDATE apipayments SET pending = ? WHERE checking_id = ?", "UPDATE apipayments SET pending = ? WHERE checking_id = ?",
( (
int(pending), int(pending),
@@ -325,12 +355,20 @@ async def update_payment_status(checking_id: str, pending: bool) -> None:
) )
async def delete_payment(checking_id: str) -> None: async def delete_payment(
await db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) checking_id: str,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).execute(
"DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)
)
async def check_internal(payment_hash: str) -> Optional[str]: async def check_internal(
row = await db.fetchone( payment_hash: str,
conn: Optional[Connection] = None,
) -> Optional[str]:
row = await (conn or db).fetchone(
""" """
SELECT checking_id FROM apipayments SELECT checking_id FROM apipayments
WHERE hash = ? AND pending AND amount > 0 WHERE hash = ? AND pending AND amount > 0

View File

@@ -13,6 +13,7 @@ except ImportError: # pragma: nocover
from typing_extensions import TypedDict from typing_extensions import TypedDict
from lnbits import bolt11 from lnbits import bolt11
from lnbits.db import Connection
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
from lnbits.settings import WALLET from lnbits.settings import WALLET
from lnbits.wallets.base import PaymentStatus, PaymentResponse from lnbits.wallets.base import PaymentStatus, PaymentResponse
@@ -36,8 +37,8 @@ async def create_invoice(
description_hash: Optional[bytes] = None, description_hash: Optional[bytes] = None,
extra: Optional[Dict] = None, extra: Optional[Dict] = None,
webhook: Optional[str] = None, webhook: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
await db.begin()
invoice_memo = None if description_hash else memo invoice_memo = None if description_hash else memo
storeable_memo = memo storeable_memo = memo
@@ -59,9 +60,9 @@ async def create_invoice(
memo=storeable_memo, memo=storeable_memo,
extra=extra, extra=extra,
webhook=webhook, webhook=webhook,
conn=conn,
) )
await db.commit()
return invoice.payment_hash, payment_request return invoice.payment_hash, payment_request
@@ -72,8 +73,9 @@ async def pay_invoice(
max_sat: Optional[int] = None, max_sat: Optional[int] = None,
extra: Optional[Dict] = None, extra: Optional[Dict] = None,
description: str = "", description: str = "",
conn: Optional[Connection] = None,
) -> str: ) -> str:
await db.begin() async with (db.reuse_conn(conn) if conn else db.connect()) as conn:
temp_id = f"temp_{urlsafe_short_hash()}" temp_id = f"temp_{urlsafe_short_hash()}"
internal_id = f"internal_{urlsafe_short_hash()}" internal_id = f"internal_{urlsafe_short_hash()}"
@@ -105,33 +107,42 @@ async def pay_invoice(
) )
# check_internal() returns the checking_id of the invoice we're waiting for # check_internal() returns the checking_id of the invoice we're waiting for
internal_checking_id = await check_internal(invoice.payment_hash) internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
if internal_checking_id: if internal_checking_id:
# create a new payment from this wallet # create a new payment from this wallet
await create_payment( await create_payment(
checking_id=internal_id, fee=0, pending=False, **payment_kwargs checking_id=internal_id,
fee=0,
pending=False,
conn=conn,
**payment_kwargs,
) )
else: else:
# create a temporary payment here so we can check if # create a temporary payment here so we can check if
# the balance is enough in the next step # the balance is enough in the next step
fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) await create_payment(
checking_id=temp_id,
fee=-fee_reserve,
conn=conn,
**payment_kwargs,
)
# do the balance check # do the balance check
wallet = await get_wallet(wallet_id) wallet = await get_wallet(wallet_id, conn=conn)
assert wallet assert wallet
if wallet.balance_msat < 0: if wallet.balance_msat < 0:
await db.rollback()
raise PermissionError("Insufficient balance.") raise PermissionError("Insufficient balance.")
else:
await db.commit()
await db.begin()
if internal_checking_id: if internal_checking_id:
# mark the invoice from the other side as not pending anymore # mark the invoice from the other side as not pending anymore
# so the other side only has access to his new money when we are sure # so the other side only has access to his new money when we are sure
# the payer has enough to deduct from # the payer has enough to deduct from
await update_payment_status(checking_id=internal_checking_id, pending=False) await update_payment_status(
checking_id=internal_checking_id,
pending=False,
conn=conn,
)
# notify receiver asynchronously # notify receiver asynchronously
from lnbits.tasks import internal_invoice_paid from lnbits.tasks import internal_invoice_paid
@@ -146,13 +157,11 @@ async def pay_invoice(
fee=payment.fee_msat, fee=payment.fee_msat,
preimage=payment.preimage, preimage=payment.preimage,
pending=payment.ok == None, pending=payment.ok == None,
conn=conn,
**payment_kwargs, **payment_kwargs,
) )
await delete_payment(temp_id) await delete_payment(temp_id, conn=conn)
await db.commit()
else: else:
await delete_payment(temp_id)
await db.commit()
raise Exception( raise Exception(
payment.error_message or "Failed to pay_invoice on backend." payment.error_message or "Failed to pay_invoice on backend."
) )
@@ -161,13 +170,17 @@ async def pay_invoice(
async def redeem_lnurl_withdraw( async def redeem_lnurl_withdraw(
wallet_id: str, res: LnurlWithdrawResponse, memo: Optional[str] = None wallet_id: str,
res: LnurlWithdrawResponse,
memo: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None: ) -> None:
_, payment_request = await create_invoice( _, payment_request = await create_invoice(
wallet_id=wallet_id, wallet_id=wallet_id,
amount=res.max_sats, amount=res.max_sats,
memo=memo or res.default_description or "", memo=memo or res.default_description or "",
extra={"tag": "lnurlwallet"}, extra={"tag": "lnurlwallet"},
conn=conn,
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@@ -180,7 +193,10 @@ async def redeem_lnurl_withdraw(
) )
async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]: async def perform_lnurlauth(
callback: str,
conn: Optional[Connection] = None,
) -> Optional[LnurlErrorResponse]:
cb = urlparse(callback) cb = urlparse(callback)
k1 = unhexlify(parse_qs(cb.query)["k1"][0]) k1 = unhexlify(parse_qs(cb.query)["k1"][0])
@@ -251,8 +267,12 @@ async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]:
) )
async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: async def check_invoice_status(
payment = await get_wallet_payment(wallet_id, payment_hash) wallet_id: str,
payment_hash: str,
conn: Optional[Connection] = None,
) -> PaymentStatus:
payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
if not payment: if not payment:
return PaymentStatus(None) return PaymentStatus(None)

View File

@@ -66,7 +66,7 @@ async def api_payments_create_invoice():
description_hash = b"" description_hash = b""
memo = g.data["memo"] memo = g.data["memo"]
try: async with db.connect() as conn:
payment_hash, payment_request = await create_invoice( payment_hash, payment_request = await create_invoice(
wallet_id=g.wallet.id, wallet_id=g.wallet.id,
amount=g.data["amount"], amount=g.data["amount"],
@@ -74,12 +74,8 @@ async def api_payments_create_invoice():
description_hash=description_hash, description_hash=description_hash,
extra=g.data.get("extra"), extra=g.data.get("extra"),
webhook=g.data.get("webhook"), webhook=g.data.get("webhook"),
conn=conn,
) )
except Exception as exc:
await db.rollback()
raise exc
await db.commit()
invoice = bolt11.decode(payment_request) invoice = bolt11.decode(payment_request)
@@ -124,14 +120,14 @@ async def api_payments_create_invoice():
async def api_payments_pay_invoice(): async def api_payments_pay_invoice():
try: try:
payment_hash = await pay_invoice( payment_hash = await pay_invoice(
wallet_id=g.wallet.id, payment_request=g.data["bolt11"] wallet_id=g.wallet.id,
payment_request=g.data["bolt11"],
) )
except ValueError as e: except ValueError as e:
return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST
except PermissionError as e: except PermissionError as e:
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
except Exception as exc: except Exception as exc:
await db.rollback()
raise exc raise exc
return ( return (
@@ -217,7 +213,6 @@ async def api_payments_pay_lnurl():
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
) )
try:
extra = {} extra = {}
if params.get("successAction"): if params.get("successAction"):
@@ -231,9 +226,6 @@ async def api_payments_pay_lnurl():
description=g.data.get("description", ""), description=g.data.get("description", ""),
extra=extra, extra=extra,
) )
except Exception as exc:
await db.rollback()
raise exc
return ( return (
jsonify( jsonify(

View File

@@ -13,11 +13,10 @@ from quart import (
) )
from lnurl import LnurlResponse, LnurlWithdrawResponse, decode as decode_lnurl # type: ignore from lnurl import LnurlResponse, LnurlWithdrawResponse, decode as decode_lnurl # type: ignore
from lnbits.core import core_app from lnbits.core import core_app, db
from lnbits.decorators import check_user_exists, validate_uuids from lnbits.decorators import check_user_exists, validate_uuids
from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE
from .. import db
from ..crud import ( from ..crud import (
create_account, create_account,
get_user, get_user,
@@ -152,10 +151,10 @@ async def lnurlwallet():
HTTPStatus.INTERNAL_SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR,
) )
account = await create_account() async with db.connect() as conn:
user = await get_user(account.id) account = await create_account(conn=conn)
wallet = await create_wallet(user_id=user.id) user = await get_user(account.id, conn=conn)
await db.commit() wallet = await create_wallet(user_id=user.id, conn=conn)
g.nursery.start_soon( g.nursery.start_soon(
redeem_lnurl_withdraw, redeem_lnurl_withdraw,

View File

@@ -1,57 +1,54 @@
import os import os
from typing import Tuple, Optional, Any import trio
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore from contextlib import asynccontextmanager
from sqlalchemy import create_engine # type: ignore from sqlalchemy import create_engine # type: ignore
from quart import g from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
from sqlalchemy_aio.base import AsyncConnection
from .settings import LNBITS_DATA_FOLDER from .settings import LNBITS_DATA_FOLDER
class Connection:
def __init__(self, conn: AsyncConnection):
self.conn = conn
async def fetchall(self, query: str, values: tuple = ()) -> list:
result = await self.conn.execute(query, values)
return await result.fetchall()
async def fetchone(self, query: str, values: tuple = ()):
result = await self.conn.execute(query, values)
row = await result.fetchone()
await result.close()
return row
async def execute(self, query: str, values: tuple = ()):
return await self.conn.execute(query, values)
class Database: class Database:
def __init__(self, db_name: str): def __init__(self, db_name: str):
self.db_name = db_name self.db_name = db_name
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3") db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY) self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY)
self.lock = trio.StrictFIFOLock()
def connect(self): @asynccontextmanager
return self.engine.connect() async def connect(self):
await self.lock.acquire()
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]:
try: try:
return getattr(g, f"{self.db_name}_conn", None), getattr( async with self.engine.connect() as conn:
g, f"{self.db_name}_txn", None async with conn.begin():
) yield Connection(conn)
except RuntimeError: finally:
return None, None self.lock.release()
async def begin(self):
conn, _ = self.session_connection()
if conn:
return
conn = await self.engine.connect()
setattr(g, f"{self.db_name}_conn", conn)
txn = await conn.begin()
setattr(g, f"{self.db_name}_txn", txn)
async def fetchall(self, query: str, values: tuple = ()) -> list: async def fetchall(self, query: str, values: tuple = ()) -> list:
conn, _ = self.session_connection()
if conn:
result = await conn.execute(query, values)
return await result.fetchall()
async with self.connect() as conn: async with self.connect() as conn:
result = await conn.execute(query, values) result = await conn.execute(query, values)
return await result.fetchall() return await result.fetchall()
async def fetchone(self, query: str, values: tuple = ()): async def fetchone(self, query: str, values: tuple = ()):
conn, _ = self.session_connection()
if conn:
result = await conn.execute(query, values)
row = await result.fetchone()
await result.close()
return row
async with self.connect() as conn: async with self.connect() as conn:
result = await conn.execute(query, values) result = await conn.execute(query, values)
row = await result.fetchone() row = await result.fetchone()
@@ -59,29 +56,9 @@ class Database:
return row return row
async def execute(self, query: str, values: tuple = ()): async def execute(self, query: str, values: tuple = ()):
conn, _ = self.session_connection()
if conn:
return await conn.execute(query, values)
async with self.connect() as conn: async with self.connect() as conn:
return await conn.execute(query, values) return await conn.execute(query, values)
async def commit(self): @asynccontextmanager
conn, txn = self.session_connection() async def reuse_conn(self, conn: Connection):
if conn and txn: yield conn
await txn.commit()
await self.close_session()
async def rollback(self):
conn, txn = self.session_connection()
if conn and txn:
await txn.rollback()
await self.close_session()
async def close_session(self):
conn, txn = self.session_connection()
if conn and txn:
await txn.close()
await conn.close()
delattr(g, f"{self.db_name}_conn")
delattr(g, f"{self.db_name}_txn")

View File

@@ -1,6 +1,6 @@
import json import json
import math import math
from quart import g, jsonify, request from quart import jsonify, request
from http import HTTPStatus from http import HTTPStatus
import traceback import traceback
@@ -29,7 +29,6 @@ from .helpers import (
# Handles signed URL from Bleskomat ATMs and "action" callback of auto-generated LNURLs. # Handles signed URL from Bleskomat ATMs and "action" callback of auto-generated LNURLs.
@bleskomat_ext.route("/u", methods=["GET"]) @bleskomat_ext.route("/u", methods=["GET"])
async def api_bleskomat_lnurl(): async def api_bleskomat_lnurl():
try: try:
query = request.args.to_dict() query = request.args.to_dict()
@@ -125,7 +124,7 @@ async def api_bleskomat_lnurl():
except LnurlHttpError as e: except LnurlHttpError as e:
return jsonify({"status": "ERROR", "reason": str(e)}), e.http_status return jsonify({"status": "ERROR", "reason": str(e)}), e.http_status
except Exception as e: except Exception:
traceback.print_exc() traceback.print_exc()
return ( return (
jsonify({"status": "ERROR", "reason": "Unexpected error"}), jsonify({"status": "ERROR", "reason": "Unexpected error"}),

View File

@@ -61,7 +61,7 @@ class BleskomatLnurl(NamedTuple):
raise LnurlValidationError("Multiple payment requests not supported") raise LnurlValidationError("Multiple payment requests not supported")
try: try:
invoice = bolt11.decode(pr) invoice = bolt11.decode(pr)
except ValueError as e: except ValueError:
raise LnurlValidationError( raise LnurlValidationError(
'Invalid parameter ("pr"): Lightning payment request expected' 'Invalid parameter ("pr"): Lightning payment request expected'
) )
@@ -79,14 +79,11 @@ class BleskomatLnurl(NamedTuple):
async def execute_action(self, query: Dict[str, str]): async def execute_action(self, query: Dict[str, str]):
self.validate_action(query) self.validate_action(query)
used = False used = False
async with db.connect() as conn:
if self.initial_uses > 0: if self.initial_uses > 0:
await db.commit() used = await self.use(conn)
await db.begin()
used = await self.use()
if not used: if not used:
await db.rollback()
raise LnurlValidationError("Maximum number of uses already reached") raise LnurlValidationError("Maximum number of uses already reached")
try:
tag = self.tag tag = self.tag
if tag == "withdrawRequest": if tag == "withdrawRequest":
payment_hash = await pay_invoice( payment_hash = await pay_invoice(
@@ -95,16 +92,10 @@ class BleskomatLnurl(NamedTuple):
) )
if not payment_hash: if not payment_hash:
raise LnurlValidationError("Failed to pay invoice") raise LnurlValidationError("Failed to pay invoice")
except Exception as e:
if used:
await db.rollback()
raise e
if used:
await db.commit()
async def use(self) -> bool: async def use(self, conn) -> bool:
now = int(time.time()) now = int(time.time())
result = await db.execute( result = await conn.execute(
""" """
UPDATE bleskomat_lnurls UPDATE bleskomat_lnurls
SET remaining_uses = remaining_uses - 1, updated_time = ? SET remaining_uses = remaining_uses - 1, updated_time = ?

View File

@@ -70,11 +70,10 @@ async def api_bleskomat_retrieve(bleskomat_id):
} }
) )
async def api_bleskomat_create_or_update(bleskomat_id=None): async def api_bleskomat_create_or_update(bleskomat_id=None):
try: try:
fiat_currency = g.data["fiat_currency"] fiat_currency = g.data["fiat_currency"]
exchange_rate_provider = g.data["exchange_rate_provider"] exchange_rate_provider = g.data["exchange_rate_provider"]
rate = await fetch_fiat_exchange_rate( await fetch_fiat_exchange_rate(
currency=fiat_currency, provider=exchange_rate_provider currency=fiat_currency, provider=exchange_rate_provider
) )
except Exception as e: except Exception as e: