mirror of
https://github.com/lnbits/lnbits.git
synced 2025-09-25 19:36:15 +02:00
fix sqlite database locked issues by using an async lock on the database and requiring explicit transaction control (or each command will be its own transaction).
This commit is contained in:
@@ -79,10 +79,6 @@ def register_blueprints(app: QuartTrio) -> None:
|
||||
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
|
||||
bp = getattr(ext_module, f"{ext.code}_ext")
|
||||
|
||||
@bp.teardown_request
|
||||
async def after_request(exc):
|
||||
await ext_module.db.close_session()
|
||||
|
||||
app.register_blueprint(bp, url_prefix=f"/{ext.code}")
|
||||
except Exception:
|
||||
raise ImportError(
|
||||
@@ -122,12 +118,6 @@ def register_request_hooks(app: QuartTrio):
|
||||
async def before_request():
|
||||
g.nursery = app.nursery
|
||||
|
||||
@app.teardown_request
|
||||
async def after_request(exc):
|
||||
from lnbits.core import db
|
||||
|
||||
await db.close_session()
|
||||
|
||||
@app.after_request
|
||||
async def set_secure_headers(response):
|
||||
secure_headers.quart(response)
|
||||
|
@@ -53,46 +53,41 @@ def bundle_vendored():
|
||||
async def migrate_databases():
|
||||
"""Creates the necessary databases if they don't exist already; or migrates them."""
|
||||
|
||||
core_conn = await core_db.connect()
|
||||
core_txn = await core_conn.begin()
|
||||
|
||||
try:
|
||||
rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall()
|
||||
except OperationalError:
|
||||
# migration 3 wasn't ran
|
||||
await core_migrations.m000_create_migrations_table(core_conn)
|
||||
rows = await (await core_conn.execute("SELECT * FROM dbversions")).fetchall()
|
||||
|
||||
current_versions = {row["db"]: row["version"] for row in rows}
|
||||
matcher = re.compile(r"^m(\d\d\d)_")
|
||||
|
||||
async def run_migration(db, migrations_module):
|
||||
db_name = migrations_module.__name__.split(".")[-2]
|
||||
for key, migrate in migrations_module.__dict__.items():
|
||||
match = match = matcher.match(key)
|
||||
if match:
|
||||
version = int(match.group(1))
|
||||
if version > current_versions.get(db_name, 0):
|
||||
print(f"running migration {db_name}.{version}")
|
||||
await migrate(db)
|
||||
await core_conn.execute(
|
||||
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)",
|
||||
(db_name, version),
|
||||
)
|
||||
|
||||
await run_migration(core_conn, core_migrations)
|
||||
|
||||
for ext in get_valid_extensions():
|
||||
async with core_db.connect() as conn:
|
||||
try:
|
||||
ext_migrations = importlib.import_module(
|
||||
f"lnbits.extensions.{ext.code}.migrations"
|
||||
)
|
||||
ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
|
||||
await run_migration(ext_db, ext_migrations)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f"Please make sure that the extension `{ext.code}` has a migrations file."
|
||||
)
|
||||
rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
|
||||
except OperationalError:
|
||||
# migration 3 wasn't ran
|
||||
await core_migrations.m000_create_migrations_table(conn)
|
||||
rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
|
||||
|
||||
await core_txn.commit()
|
||||
await core_conn.close()
|
||||
current_versions = {row["db"]: row["version"] for row in rows}
|
||||
matcher = re.compile(r"^m(\d\d\d)_")
|
||||
|
||||
async def run_migration(db, migrations_module):
|
||||
db_name = migrations_module.__name__.split(".")[-2]
|
||||
for key, migrate in migrations_module.__dict__.items():
|
||||
match = match = matcher.match(key)
|
||||
if match:
|
||||
version = int(match.group(1))
|
||||
if version > current_versions.get(db_name, 0):
|
||||
print(f"running migration {db_name}.{version}")
|
||||
await migrate(db)
|
||||
await conn.execute(
|
||||
"INSERT OR REPLACE INTO dbversions (db, version) VALUES (?, ?)",
|
||||
(db_name, version),
|
||||
)
|
||||
|
||||
await run_migration(conn, core_migrations)
|
||||
|
||||
for ext in get_valid_extensions():
|
||||
try:
|
||||
ext_migrations = importlib.import_module(
|
||||
f"lnbits.extensions.{ext.code}.migrations"
|
||||
)
|
||||
ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
|
||||
await run_migration(ext_db, ext_migrations)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f"Please make sure that the extension `{ext.code}` has a migrations file."
|
||||
)
|
||||
|
@@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from lnbits import bolt11
|
||||
from lnbits.db import Connection
|
||||
from lnbits.settings import DEFAULT_WALLET_NAME
|
||||
|
||||
from . import db
|
||||
@@ -14,32 +15,36 @@ from .models import User, Wallet, Payment
|
||||
# --------
|
||||
|
||||
|
||||
async def create_account() -> User:
|
||||
async def create_account(conn: Optional[Connection] = None) -> User:
|
||||
user_id = uuid4().hex
|
||||
await db.execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
|
||||
await (conn or db).execute("INSERT INTO accounts (id) VALUES (?)", (user_id,))
|
||||
|
||||
new_account = await get_account(user_id=user_id)
|
||||
new_account = await get_account(user_id=user_id, conn=conn)
|
||||
assert new_account, "Newly created account couldn't be retrieved"
|
||||
|
||||
return new_account
|
||||
|
||||
|
||||
async def get_account(user_id: str) -> Optional[User]:
|
||||
row = await db.fetchone(
|
||||
async def get_account(
|
||||
user_id: str, conn: Optional[Connection] = None
|
||||
) -> Optional[User]:
|
||||
row = await (conn or db).fetchone(
|
||||
"SELECT id, email, pass as password FROM accounts WHERE id = ?", (user_id,)
|
||||
)
|
||||
|
||||
return User(**row) if row else None
|
||||
|
||||
|
||||
async def get_user(user_id: str) -> Optional[User]:
|
||||
user = await db.fetchone("SELECT id, email FROM accounts WHERE id = ?", (user_id,))
|
||||
async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]:
|
||||
user = await (conn or db).fetchone(
|
||||
"SELECT id, email FROM accounts WHERE id = ?", (user_id,)
|
||||
)
|
||||
|
||||
if user:
|
||||
extensions = await db.fetchall(
|
||||
extensions = await (conn or db).fetchall(
|
||||
"SELECT extension FROM extensions WHERE user = ? AND active = 1", (user_id,)
|
||||
)
|
||||
wallets = await db.fetchall(
|
||||
wallets = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
|
||||
FROM wallets
|
||||
@@ -63,8 +68,10 @@ async def get_user(user_id: str) -> Optional[User]:
|
||||
)
|
||||
|
||||
|
||||
async def update_user_extension(*, user_id: str, extension: str, active: int) -> None:
|
||||
await db.execute(
|
||||
async def update_user_extension(
|
||||
*, user_id: str, extension: str, active: int, conn: Optional[Connection] = None
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO extensions (user, extension, active)
|
||||
VALUES (?, ?, ?)
|
||||
@@ -77,9 +84,14 @@ async def update_user_extension(*, user_id: str, extension: str, active: int) ->
|
||||
# -------
|
||||
|
||||
|
||||
async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> Wallet:
|
||||
async def create_wallet(
|
||||
*,
|
||||
user_id: str,
|
||||
wallet_name: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Wallet:
|
||||
wallet_id = uuid4().hex
|
||||
await db.execute(
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
INSERT INTO wallets (id, name, user, adminkey, inkey)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
@@ -93,14 +105,16 @@ async def create_wallet(*, user_id: str, wallet_name: Optional[str] = None) -> W
|
||||
),
|
||||
)
|
||||
|
||||
new_wallet = await get_wallet(wallet_id=wallet_id)
|
||||
new_wallet = await get_wallet(wallet_id=wallet_id, conn=conn)
|
||||
assert new_wallet, "Newly created wallet couldn't be retrieved"
|
||||
|
||||
return new_wallet
|
||||
|
||||
|
||||
async def delete_wallet(*, user_id: str, wallet_id: str) -> None:
|
||||
await db.execute(
|
||||
async def delete_wallet(
|
||||
*, user_id: str, wallet_id: str, conn: Optional[Connection] = None
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
UPDATE wallets AS w
|
||||
SET
|
||||
@@ -113,8 +127,10 @@ async def delete_wallet(*, user_id: str, wallet_id: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def get_wallet(wallet_id: str) -> Optional[Wallet]:
|
||||
row = await db.fetchone(
|
||||
async def get_wallet(
|
||||
wallet_id: str, conn: Optional[Connection] = None
|
||||
) -> Optional[Wallet]:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
|
||||
FROM wallets
|
||||
@@ -126,8 +142,10 @@ async def get_wallet(wallet_id: str) -> Optional[Wallet]:
|
||||
return Wallet(**row) if row else None
|
||||
|
||||
|
||||
async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wallet]:
|
||||
row = await db.fetchone(
|
||||
async def get_wallet_for_key(
|
||||
key: str, key_type: str = "invoice", conn: Optional[Connection] = None
|
||||
) -> Optional[Wallet]:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) AS balance_msat
|
||||
FROM wallets
|
||||
@@ -149,8 +167,10 @@ async def get_wallet_for_key(key: str, key_type: str = "invoice") -> Optional[Wa
|
||||
# ---------------
|
||||
|
||||
|
||||
async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]:
|
||||
row = await db.fetchone(
|
||||
async def get_standalone_payment(
|
||||
checking_id_or_hash: str, conn: Optional[Connection] = None
|
||||
) -> Optional[Payment]:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT *
|
||||
FROM apipayments
|
||||
@@ -163,8 +183,10 @@ async def get_standalone_payment(checking_id_or_hash: str) -> Optional[Payment]:
|
||||
return Payment.from_row(row) if row else None
|
||||
|
||||
|
||||
async def get_wallet_payment(wallet_id: str, payment_hash: str) -> Optional[Payment]:
|
||||
row = await db.fetchone(
|
||||
async def get_wallet_payment(
|
||||
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
|
||||
) -> Optional[Payment]:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT *
|
||||
FROM apipayments
|
||||
@@ -185,6 +207,7 @@ async def get_payments(
|
||||
incoming: bool = False,
|
||||
since: Optional[int] = None,
|
||||
exclude_uncheckable: bool = False,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[Payment]:
|
||||
"""
|
||||
Filters payments to be returned by complete | pending | outgoing | incoming.
|
||||
@@ -227,7 +250,7 @@ async def get_payments(
|
||||
if clause:
|
||||
where = f"WHERE {' AND '.join(clause)}"
|
||||
|
||||
rows = await db.fetchall(
|
||||
rows = await (conn or db).fetchall(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM apipayments
|
||||
@@ -240,8 +263,10 @@ async def get_payments(
|
||||
return [Payment.from_row(row) for row in rows]
|
||||
|
||||
|
||||
async def delete_expired_invoices() -> None:
|
||||
rows = await db.fetchall(
|
||||
async def delete_expired_invoices(
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
rows = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT bolt11
|
||||
FROM apipayments
|
||||
@@ -258,7 +283,7 @@ async def delete_expired_invoices() -> None:
|
||||
if expiration_date > datetime.datetime.utcnow():
|
||||
continue
|
||||
|
||||
await db.execute(
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
DELETE FROM apipayments
|
||||
WHERE pending = 1 AND hash = ?
|
||||
@@ -284,8 +309,9 @@ async def create_payment(
|
||||
pending: bool = True,
|
||||
extra: Optional[Dict] = None,
|
||||
webhook: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Payment:
|
||||
await db.execute(
|
||||
await (conn or db).execute(
|
||||
"""
|
||||
INSERT INTO apipayments
|
||||
(wallet, checking_id, bolt11, hash, preimage,
|
||||
@@ -309,14 +335,18 @@ async def create_payment(
|
||||
),
|
||||
)
|
||||
|
||||
new_payment = await get_wallet_payment(wallet_id, payment_hash)
|
||||
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
|
||||
assert new_payment, "Newly created payment couldn't be retrieved"
|
||||
|
||||
return new_payment
|
||||
|
||||
|
||||
async def update_payment_status(checking_id: str, pending: bool) -> None:
|
||||
await db.execute(
|
||||
async def update_payment_status(
|
||||
checking_id: str,
|
||||
pending: bool,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
"UPDATE apipayments SET pending = ? WHERE checking_id = ?",
|
||||
(
|
||||
int(pending),
|
||||
@@ -325,12 +355,20 @@ async def update_payment_status(checking_id: str, pending: bool) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def delete_payment(checking_id: str) -> None:
|
||||
await db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
|
||||
async def delete_payment(
|
||||
checking_id: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
await (conn or db).execute(
|
||||
"DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)
|
||||
)
|
||||
|
||||
|
||||
async def check_internal(payment_hash: str) -> Optional[str]:
|
||||
row = await db.fetchone(
|
||||
async def check_internal(
|
||||
payment_hash: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[str]:
|
||||
row = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT checking_id FROM apipayments
|
||||
WHERE hash = ? AND pending AND amount > 0
|
||||
|
@@ -13,6 +13,7 @@ except ImportError: # pragma: nocover
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from lnbits import bolt11
|
||||
from lnbits.db import Connection
|
||||
from lnbits.helpers import urlsafe_short_hash
|
||||
from lnbits.settings import WALLET
|
||||
from lnbits.wallets.base import PaymentStatus, PaymentResponse
|
||||
@@ -36,8 +37,8 @@ async def create_invoice(
|
||||
description_hash: Optional[bytes] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
webhook: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Tuple[str, str]:
|
||||
await db.begin()
|
||||
invoice_memo = None if description_hash else memo
|
||||
storeable_memo = memo
|
||||
|
||||
@@ -59,9 +60,9 @@ async def create_invoice(
|
||||
memo=storeable_memo,
|
||||
extra=extra,
|
||||
webhook=webhook,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
return invoice.payment_hash, payment_request
|
||||
|
||||
|
||||
@@ -72,102 +73,114 @@ async def pay_invoice(
|
||||
max_sat: Optional[int] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
description: str = "",
|
||||
conn: Optional[Connection] = None,
|
||||
) -> str:
|
||||
await db.begin()
|
||||
temp_id = f"temp_{urlsafe_short_hash()}"
|
||||
internal_id = f"internal_{urlsafe_short_hash()}"
|
||||
async with (db.reuse_conn(conn) if conn else db.connect()) as conn:
|
||||
temp_id = f"temp_{urlsafe_short_hash()}"
|
||||
internal_id = f"internal_{urlsafe_short_hash()}"
|
||||
|
||||
invoice = bolt11.decode(payment_request)
|
||||
if invoice.amount_msat == 0:
|
||||
raise ValueError("Amountless invoices not supported.")
|
||||
if max_sat and invoice.amount_msat > max_sat * 1000:
|
||||
raise ValueError("Amount in invoice is too high.")
|
||||
invoice = bolt11.decode(payment_request)
|
||||
if invoice.amount_msat == 0:
|
||||
raise ValueError("Amountless invoices not supported.")
|
||||
if max_sat and invoice.amount_msat > max_sat * 1000:
|
||||
raise ValueError("Amount in invoice is too high.")
|
||||
|
||||
# put all parameters that don't change here
|
||||
PaymentKwargs = TypedDict(
|
||||
"PaymentKwargs",
|
||||
{
|
||||
"wallet_id": str,
|
||||
"payment_request": str,
|
||||
"payment_hash": str,
|
||||
"amount": int,
|
||||
"memo": str,
|
||||
"extra": Optional[Dict],
|
||||
},
|
||||
)
|
||||
payment_kwargs: PaymentKwargs = dict(
|
||||
wallet_id=wallet_id,
|
||||
payment_request=payment_request,
|
||||
payment_hash=invoice.payment_hash,
|
||||
amount=-invoice.amount_msat,
|
||||
memo=description or invoice.description or "",
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
# check_internal() returns the checking_id of the invoice we're waiting for
|
||||
internal_checking_id = await check_internal(invoice.payment_hash)
|
||||
if internal_checking_id:
|
||||
# create a new payment from this wallet
|
||||
await create_payment(
|
||||
checking_id=internal_id, fee=0, pending=False, **payment_kwargs
|
||||
# put all parameters that don't change here
|
||||
PaymentKwargs = TypedDict(
|
||||
"PaymentKwargs",
|
||||
{
|
||||
"wallet_id": str,
|
||||
"payment_request": str,
|
||||
"payment_hash": str,
|
||||
"amount": int,
|
||||
"memo": str,
|
||||
"extra": Optional[Dict],
|
||||
},
|
||||
)
|
||||
payment_kwargs: PaymentKwargs = dict(
|
||||
wallet_id=wallet_id,
|
||||
payment_request=payment_request,
|
||||
payment_hash=invoice.payment_hash,
|
||||
amount=-invoice.amount_msat,
|
||||
memo=description or invoice.description or "",
|
||||
extra=extra,
|
||||
)
|
||||
else:
|
||||
# create a temporary payment here so we can check if
|
||||
# the balance is enough in the next step
|
||||
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
|
||||
await create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs)
|
||||
|
||||
# do the balance check
|
||||
wallet = await get_wallet(wallet_id)
|
||||
assert wallet
|
||||
if wallet.balance_msat < 0:
|
||||
await db.rollback()
|
||||
raise PermissionError("Insufficient balance.")
|
||||
else:
|
||||
await db.commit()
|
||||
await db.begin()
|
||||
|
||||
if internal_checking_id:
|
||||
# mark the invoice from the other side as not pending anymore
|
||||
# so the other side only has access to his new money when we are sure
|
||||
# the payer has enough to deduct from
|
||||
await update_payment_status(checking_id=internal_checking_id, pending=False)
|
||||
|
||||
# notify receiver asynchronously
|
||||
from lnbits.tasks import internal_invoice_paid
|
||||
|
||||
await internal_invoice_paid.send(internal_checking_id)
|
||||
else:
|
||||
# actually pay the external invoice
|
||||
payment: PaymentResponse = await WALLET.pay_invoice(payment_request)
|
||||
if payment.checking_id:
|
||||
# check_internal() returns the checking_id of the invoice we're waiting for
|
||||
internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
|
||||
if internal_checking_id:
|
||||
# create a new payment from this wallet
|
||||
await create_payment(
|
||||
checking_id=payment.checking_id,
|
||||
fee=payment.fee_msat,
|
||||
preimage=payment.preimage,
|
||||
pending=payment.ok == None,
|
||||
checking_id=internal_id,
|
||||
fee=0,
|
||||
pending=False,
|
||||
conn=conn,
|
||||
**payment_kwargs,
|
||||
)
|
||||
await delete_payment(temp_id)
|
||||
await db.commit()
|
||||
else:
|
||||
await delete_payment(temp_id)
|
||||
await db.commit()
|
||||
raise Exception(
|
||||
payment.error_message or "Failed to pay_invoice on backend."
|
||||
# create a temporary payment here so we can check if
|
||||
# the balance is enough in the next step
|
||||
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
|
||||
await create_payment(
|
||||
checking_id=temp_id,
|
||||
fee=-fee_reserve,
|
||||
conn=conn,
|
||||
**payment_kwargs,
|
||||
)
|
||||
|
||||
return invoice.payment_hash
|
||||
# do the balance check
|
||||
wallet = await get_wallet(wallet_id, conn=conn)
|
||||
assert wallet
|
||||
if wallet.balance_msat < 0:
|
||||
raise PermissionError("Insufficient balance.")
|
||||
|
||||
if internal_checking_id:
|
||||
# mark the invoice from the other side as not pending anymore
|
||||
# so the other side only has access to his new money when we are sure
|
||||
# the payer has enough to deduct from
|
||||
await update_payment_status(
|
||||
checking_id=internal_checking_id,
|
||||
pending=False,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
# notify receiver asynchronously
|
||||
from lnbits.tasks import internal_invoice_paid
|
||||
|
||||
await internal_invoice_paid.send(internal_checking_id)
|
||||
else:
|
||||
# actually pay the external invoice
|
||||
payment: PaymentResponse = await WALLET.pay_invoice(payment_request)
|
||||
if payment.checking_id:
|
||||
await create_payment(
|
||||
checking_id=payment.checking_id,
|
||||
fee=payment.fee_msat,
|
||||
preimage=payment.preimage,
|
||||
pending=payment.ok == None,
|
||||
conn=conn,
|
||||
**payment_kwargs,
|
||||
)
|
||||
await delete_payment(temp_id, conn=conn)
|
||||
else:
|
||||
raise Exception(
|
||||
payment.error_message or "Failed to pay_invoice on backend."
|
||||
)
|
||||
|
||||
return invoice.payment_hash
|
||||
|
||||
|
||||
async def redeem_lnurl_withdraw(
|
||||
wallet_id: str, res: LnurlWithdrawResponse, memo: Optional[str] = None
|
||||
wallet_id: str,
|
||||
res: LnurlWithdrawResponse,
|
||||
memo: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
_, payment_request = await create_invoice(
|
||||
wallet_id=wallet_id,
|
||||
amount=res.max_sats,
|
||||
memo=memo or res.default_description or "",
|
||||
extra={"tag": "lnurlwallet"},
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -180,7 +193,10 @@ async def redeem_lnurl_withdraw(
|
||||
)
|
||||
|
||||
|
||||
async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]:
|
||||
async def perform_lnurlauth(
|
||||
callback: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Optional[LnurlErrorResponse]:
|
||||
cb = urlparse(callback)
|
||||
|
||||
k1 = unhexlify(parse_qs(cb.query)["k1"][0])
|
||||
@@ -251,8 +267,12 @@ async def perform_lnurlauth(callback: str) -> Optional[LnurlErrorResponse]:
|
||||
)
|
||||
|
||||
|
||||
async def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus:
|
||||
payment = await get_wallet_payment(wallet_id, payment_hash)
|
||||
async def check_invoice_status(
|
||||
wallet_id: str,
|
||||
payment_hash: str,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> PaymentStatus:
|
||||
payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
|
||||
if not payment:
|
||||
return PaymentStatus(None)
|
||||
|
||||
|
@@ -66,7 +66,7 @@ async def api_payments_create_invoice():
|
||||
description_hash = b""
|
||||
memo = g.data["memo"]
|
||||
|
||||
try:
|
||||
async with db.connect() as conn:
|
||||
payment_hash, payment_request = await create_invoice(
|
||||
wallet_id=g.wallet.id,
|
||||
amount=g.data["amount"],
|
||||
@@ -74,12 +74,8 @@ async def api_payments_create_invoice():
|
||||
description_hash=description_hash,
|
||||
extra=g.data.get("extra"),
|
||||
webhook=g.data.get("webhook"),
|
||||
conn=conn,
|
||||
)
|
||||
except Exception as exc:
|
||||
await db.rollback()
|
||||
raise exc
|
||||
|
||||
await db.commit()
|
||||
|
||||
invoice = bolt11.decode(payment_request)
|
||||
|
||||
@@ -124,14 +120,14 @@ async def api_payments_create_invoice():
|
||||
async def api_payments_pay_invoice():
|
||||
try:
|
||||
payment_hash = await pay_invoice(
|
||||
wallet_id=g.wallet.id, payment_request=g.data["bolt11"]
|
||||
wallet_id=g.wallet.id,
|
||||
payment_request=g.data["bolt11"],
|
||||
)
|
||||
except ValueError as e:
|
||||
return jsonify({"message": str(e)}), HTTPStatus.BAD_REQUEST
|
||||
except PermissionError as e:
|
||||
return jsonify({"message": str(e)}), HTTPStatus.FORBIDDEN
|
||||
except Exception as exc:
|
||||
await db.rollback()
|
||||
raise exc
|
||||
|
||||
return (
|
||||
@@ -217,23 +213,19 @@ async def api_payments_pay_lnurl():
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
try:
|
||||
extra = {}
|
||||
extra = {}
|
||||
|
||||
if params.get("successAction"):
|
||||
extra["success_action"] = params["successAction"]
|
||||
if g.data["comment"]:
|
||||
extra["comment"] = g.data["comment"]
|
||||
if params.get("successAction"):
|
||||
extra["success_action"] = params["successAction"]
|
||||
if g.data["comment"]:
|
||||
extra["comment"] = g.data["comment"]
|
||||
|
||||
payment_hash = await pay_invoice(
|
||||
wallet_id=g.wallet.id,
|
||||
payment_request=params["pr"],
|
||||
description=g.data.get("description", ""),
|
||||
extra=extra,
|
||||
)
|
||||
except Exception as exc:
|
||||
await db.rollback()
|
||||
raise exc
|
||||
payment_hash = await pay_invoice(
|
||||
wallet_id=g.wallet.id,
|
||||
payment_request=params["pr"],
|
||||
description=g.data.get("description", ""),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
return (
|
||||
jsonify(
|
||||
|
@@ -13,11 +13,10 @@ from quart import (
|
||||
)
|
||||
from lnurl import LnurlResponse, LnurlWithdrawResponse, decode as decode_lnurl # type: ignore
|
||||
|
||||
from lnbits.core import core_app
|
||||
from lnbits.core import core_app, db
|
||||
from lnbits.decorators import check_user_exists, validate_uuids
|
||||
from lnbits.settings import LNBITS_ALLOWED_USERS, SERVICE_FEE
|
||||
|
||||
from .. import db
|
||||
from ..crud import (
|
||||
create_account,
|
||||
get_user,
|
||||
@@ -152,10 +151,10 @@ async def lnurlwallet():
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
account = await create_account()
|
||||
user = await get_user(account.id)
|
||||
wallet = await create_wallet(user_id=user.id)
|
||||
await db.commit()
|
||||
async with db.connect() as conn:
|
||||
account = await create_account(conn=conn)
|
||||
user = await get_user(account.id, conn=conn)
|
||||
wallet = await create_wallet(user_id=user.id, conn=conn)
|
||||
|
||||
g.nursery.start_soon(
|
||||
redeem_lnurl_withdraw,
|
||||
|
91
lnbits/db.py
91
lnbits/db.py
@@ -1,57 +1,54 @@
|
||||
import os
|
||||
from typing import Tuple, Optional, Any
|
||||
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
|
||||
import trio
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy import create_engine # type: ignore
|
||||
from quart import g
|
||||
from sqlalchemy_aio import TRIO_STRATEGY # type: ignore
|
||||
from sqlalchemy_aio.base import AsyncConnection
|
||||
|
||||
from .settings import LNBITS_DATA_FOLDER
|
||||
|
||||
|
||||
class Connection:
|
||||
def __init__(self, conn: AsyncConnection):
|
||||
self.conn = conn
|
||||
|
||||
async def fetchall(self, query: str, values: tuple = ()) -> list:
|
||||
result = await self.conn.execute(query, values)
|
||||
return await result.fetchall()
|
||||
|
||||
async def fetchone(self, query: str, values: tuple = ()):
|
||||
result = await self.conn.execute(query, values)
|
||||
row = await result.fetchone()
|
||||
await result.close()
|
||||
return row
|
||||
|
||||
async def execute(self, query: str, values: tuple = ()):
|
||||
return await self.conn.execute(query, values)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_name: str):
|
||||
self.db_name = db_name
|
||||
db_path = os.path.join(LNBITS_DATA_FOLDER, f"{db_name}.sqlite3")
|
||||
self.engine = create_engine(f"sqlite:///{db_path}", strategy=TRIO_STRATEGY)
|
||||
self.lock = trio.StrictFIFOLock()
|
||||
|
||||
def connect(self):
|
||||
return self.engine.connect()
|
||||
|
||||
def session_connection(self) -> Tuple[Optional[Any], Optional[Any]]:
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
await self.lock.acquire()
|
||||
try:
|
||||
return getattr(g, f"{self.db_name}_conn", None), getattr(
|
||||
g, f"{self.db_name}_txn", None
|
||||
)
|
||||
except RuntimeError:
|
||||
return None, None
|
||||
|
||||
async def begin(self):
|
||||
conn, _ = self.session_connection()
|
||||
if conn:
|
||||
return
|
||||
|
||||
conn = await self.engine.connect()
|
||||
setattr(g, f"{self.db_name}_conn", conn)
|
||||
txn = await conn.begin()
|
||||
setattr(g, f"{self.db_name}_txn", txn)
|
||||
async with self.engine.connect() as conn:
|
||||
async with conn.begin():
|
||||
yield Connection(conn)
|
||||
finally:
|
||||
self.lock.release()
|
||||
|
||||
async def fetchall(self, query: str, values: tuple = ()) -> list:
|
||||
conn, _ = self.session_connection()
|
||||
if conn:
|
||||
result = await conn.execute(query, values)
|
||||
return await result.fetchall()
|
||||
|
||||
async with self.connect() as conn:
|
||||
result = await conn.execute(query, values)
|
||||
return await result.fetchall()
|
||||
|
||||
async def fetchone(self, query: str, values: tuple = ()):
|
||||
conn, _ = self.session_connection()
|
||||
if conn:
|
||||
result = await conn.execute(query, values)
|
||||
row = await result.fetchone()
|
||||
await result.close()
|
||||
return row
|
||||
|
||||
async with self.connect() as conn:
|
||||
result = await conn.execute(query, values)
|
||||
row = await result.fetchone()
|
||||
@@ -59,29 +56,9 @@ class Database:
|
||||
return row
|
||||
|
||||
async def execute(self, query: str, values: tuple = ()):
|
||||
conn, _ = self.session_connection()
|
||||
if conn:
|
||||
return await conn.execute(query, values)
|
||||
|
||||
async with self.connect() as conn:
|
||||
return await conn.execute(query, values)
|
||||
|
||||
async def commit(self):
|
||||
conn, txn = self.session_connection()
|
||||
if conn and txn:
|
||||
await txn.commit()
|
||||
await self.close_session()
|
||||
|
||||
async def rollback(self):
|
||||
conn, txn = self.session_connection()
|
||||
if conn and txn:
|
||||
await txn.rollback()
|
||||
await self.close_session()
|
||||
|
||||
async def close_session(self):
|
||||
conn, txn = self.session_connection()
|
||||
if conn and txn:
|
||||
await txn.close()
|
||||
await conn.close()
|
||||
delattr(g, f"{self.db_name}_conn")
|
||||
delattr(g, f"{self.db_name}_txn")
|
||||
@asynccontextmanager
|
||||
async def reuse_conn(self, conn: Connection):
|
||||
yield conn
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import math
|
||||
from quart import g, jsonify, request
|
||||
from quart import jsonify, request
|
||||
from http import HTTPStatus
|
||||
import traceback
|
||||
|
||||
@@ -29,7 +29,6 @@ from .helpers import (
|
||||
# Handles signed URL from Bleskomat ATMs and "action" callback of auto-generated LNURLs.
|
||||
@bleskomat_ext.route("/u", methods=["GET"])
|
||||
async def api_bleskomat_lnurl():
|
||||
|
||||
try:
|
||||
query = request.args.to_dict()
|
||||
|
||||
@@ -125,7 +124,7 @@ async def api_bleskomat_lnurl():
|
||||
|
||||
except LnurlHttpError as e:
|
||||
return jsonify({"status": "ERROR", "reason": str(e)}), e.http_status
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return (
|
||||
jsonify({"status": "ERROR", "reason": "Unexpected error"}),
|
||||
|
@@ -61,7 +61,7 @@ class BleskomatLnurl(NamedTuple):
|
||||
raise LnurlValidationError("Multiple payment requests not supported")
|
||||
try:
|
||||
invoice = bolt11.decode(pr)
|
||||
except ValueError as e:
|
||||
except ValueError:
|
||||
raise LnurlValidationError(
|
||||
'Invalid parameter ("pr"): Lightning payment request expected'
|
||||
)
|
||||
@@ -79,14 +79,11 @@ class BleskomatLnurl(NamedTuple):
|
||||
async def execute_action(self, query: Dict[str, str]):
|
||||
self.validate_action(query)
|
||||
used = False
|
||||
if self.initial_uses > 0:
|
||||
await db.commit()
|
||||
await db.begin()
|
||||
used = await self.use()
|
||||
if not used:
|
||||
await db.rollback()
|
||||
raise LnurlValidationError("Maximum number of uses already reached")
|
||||
try:
|
||||
async with db.connect() as conn:
|
||||
if self.initial_uses > 0:
|
||||
used = await self.use(conn)
|
||||
if not used:
|
||||
raise LnurlValidationError("Maximum number of uses already reached")
|
||||
tag = self.tag
|
||||
if tag == "withdrawRequest":
|
||||
payment_hash = await pay_invoice(
|
||||
@@ -95,16 +92,10 @@ class BleskomatLnurl(NamedTuple):
|
||||
)
|
||||
if not payment_hash:
|
||||
raise LnurlValidationError("Failed to pay invoice")
|
||||
except Exception as e:
|
||||
if used:
|
||||
await db.rollback()
|
||||
raise e
|
||||
if used:
|
||||
await db.commit()
|
||||
|
||||
async def use(self) -> bool:
|
||||
async def use(self, conn) -> bool:
|
||||
now = int(time.time())
|
||||
result = await db.execute(
|
||||
result = await conn.execute(
|
||||
"""
|
||||
UPDATE bleskomat_lnurls
|
||||
SET remaining_uses = remaining_uses - 1, updated_time = ?
|
||||
|
@@ -70,11 +70,10 @@ async def api_bleskomat_retrieve(bleskomat_id):
|
||||
}
|
||||
)
|
||||
async def api_bleskomat_create_or_update(bleskomat_id=None):
|
||||
|
||||
try:
|
||||
fiat_currency = g.data["fiat_currency"]
|
||||
exchange_rate_provider = g.data["exchange_rate_provider"]
|
||||
rate = await fetch_fiat_exchange_rate(
|
||||
await fetch_fiat_exchange_rate(
|
||||
currency=fiat_currency, provider=exchange_rate_provider
|
||||
)
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user