mega chore: update sqlalchemy (#2611)

* update sqlalchemy to 1.4
* async postgres
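
In practice this swaps positional `?` tuples for named `:key` dicts and replaces the sqlalchemy_aio engine with SQLAlchemy's native asyncio engine. A minimal sketch of the new pattern (the table, values, and in-memory DSN are illustrative, not from this diff; requires the aiosqlite driver):

    import asyncio
    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.connect() as conn:
            await conn.execute(text("CREATE TABLE t (id INTEGER, memo TEXT)"))
            # before: ("hi", 1) bound to "VALUES (?, ?)"; now named parameters
            await conn.execute(
                text("INSERT INTO t (id, memo) VALUES (:id, :memo)"),
                {"id": 1, "memo": "hi"},
            )
            result = await conn.execute(
                text("SELECT memo FROM t WHERE id = :id"), {"id": 1}
            )
            # rows come back as dict-like mappings instead of positional tuples
            row = result.mappings().first()
            assert row and row["memo"] == "hi"
        await engine.dispose()

    asyncio.run(main())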

---------

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
dni ⚡ 2024-09-24 10:56:03 +02:00 committed by GitHub
parent c637e8d31e
commit 21d87adc52
17 changed files with 1020 additions and 951 deletions

View File

@@ -46,7 +46,10 @@ runs:
     - name: Install the project dependencies
       shell: bash
-      run: poetry install
+      run: |
+        poetry install
+        # needed for conv tests
+        poetry add psycopg2-binary
     - name: Use Node.js ${{ inputs.node-version }}
       if: ${{ (inputs.npm == 'true') }}

View File

@@ -30,6 +30,7 @@
           meta.rev = self.dirtyRev or self.rev;
           meta.mainProgram = projectName;
           overrides = pkgs.poetry2nix.overrides.withDefaults (final: prev: {
+            coincurve = prev.coincurve.override { preferWheel = true; };
             protobuf = prev.protobuf.override { preferWheel = true; };
             ruff = prev.ruff.override { preferWheel = true; };
             wallycore = prev.wallycore.override { preferWheel = true; };

File diff suppressed because it is too large

View File

@@ -1,4 +1,3 @@
-import datetime
 from time import time

 from loguru import logger
@@ -102,7 +101,7 @@ async def m002_add_fields_to_apipayments(db):
     import json

-    rows = await (await db.execute("SELECT * FROM apipayments")).fetchall()
+    rows = await db.fetchall("SELECT * FROM apipayments")
     for row in rows:
         if not row["memo"] or not row["memo"].startswith("#"):
             continue
@@ -113,15 +112,15 @@ async def m002_add_fields_to_apipayments(db):
                 new = row["memo"][len(prefix) :]
                 await db.execute(
                     """
-                    UPDATE apipayments SET extra = ?, memo = ?
-                    WHERE checking_id = ? AND memo = ?
+                    UPDATE apipayments SET extra = :extra, memo = :memo1
+                    WHERE checking_id = :checking_id AND memo = :memo2
                     """,
-                    (
-                        json.dumps({"tag": ext}),
-                        new,
-                        row["checking_id"],
-                        row["memo"],
-                    ),
+                    {
+                        "extra": json.dumps({"tag": ext}),
+                        "memo1": new,
+                        "checking_id": row["checking_id"],
+                        "memo2": row["memo"],
+                    },
                 )
                 break
         except OperationalError:
@@ -212,19 +211,17 @@ async def m007_set_invoice_expiries(db):
     Precomputes invoice expiry for existing pending incoming payments.
     """
    try:
-        rows = await (
-            await db.execute(
-                f"""
-                SELECT bolt11, checking_id
-                FROM apipayments
-                WHERE pending = true
-                AND amount > 0
-                AND bolt11 IS NOT NULL
-                AND expiry IS NULL
-                AND time < {db.timestamp_now}
-                """
-            )
-        ).fetchall()
+        rows = await db.fetchall(
+            f"""
+            SELECT bolt11, checking_id
+            FROM apipayments
+            WHERE pending = true
+            AND amount > 0
+            AND bolt11 IS NOT NULL
+            AND expiry IS NULL
+            AND time < {db.timestamp_now}
+            """
+        )
         if len(rows):
             logger.info(f"Migration: Checking expiry of {len(rows)} invoices")
         for i, (
@@ -236,22 +233,17 @@ async def m007_set_invoice_expiries(db):
                 if invoice.expiry is None:
                     continue

-                expiration_date = datetime.datetime.fromtimestamp(
-                    invoice.date + invoice.expiry
-                )
+                expiration_date = invoice.date + invoice.expiry
                 logger.info(
                     f"Migration: {i+1}/{len(rows)} setting expiry of invoice"
                     f" {invoice.payment_hash} to {expiration_date}"
                 )
                 await db.execute(
-                    """
-                    UPDATE apipayments SET expiry = ?
-                    WHERE checking_id = ? AND amount > 0
+                    f"""
+                    UPDATE apipayments SET expiry = {db.timestamp_placeholder('expiry')}
+                    WHERE checking_id = :checking_id AND amount > 0
                     """,
-                    (
-                        db.datetime_to_timestamp(expiration_date),
-                        checking_id,
-                    ),
+                    {"expiry": expiration_date, "checking_id": checking_id},
                 )
             except Exception:
                 continue
@@ -347,17 +339,15 @@ async def m014_set_deleted_wallets(db):
     Sets deleted column to wallets.
     """
     try:
-        rows = await (
-            await db.execute(
-                """
-                SELECT *
-                FROM wallets
-                WHERE user LIKE 'del:%'
-                AND adminkey LIKE 'del:%'
-                AND inkey LIKE 'del:%'
-                """
-            )
-        ).fetchall()
+        rows = await db.fetchall(
+            """
+            SELECT *
+            FROM wallets
+            WHERE user LIKE 'del:%'
+            AND adminkey LIKE 'del:%'
+            AND inkey LIKE 'del:%'
+            """
+        )

         for row in rows:
             try:
@@ -367,10 +357,15 @@ async def m014_set_deleted_wallets(db):
                 await db.execute(
                     """
                     UPDATE wallets SET
-                    "user" = ?, adminkey = ?, inkey = ?, deleted = true
-                    WHERE id = ?
+                    "user" = :user, adminkey = :adminkey, inkey = :inkey, deleted = true
+                    WHERE id = :wallet
                     """,
-                    (user, adminkey, inkey, row[0]),
+                    {
+                        "user": user,
+                        "adminkey": adminkey,
+                        "inkey": inkey,
+                        "wallet": row.get("id"),
+                    },
                 )
             except Exception:
                 continue
@@ -456,17 +451,17 @@ async def m017_add_timestamp_columns_to_accounts_and_wallets(db):
         now = int(time())
         await db.execute(
             f"""
-            UPDATE wallets SET created_at = {db.timestamp_placeholder}
+            UPDATE wallets SET created_at = {db.timestamp_placeholder('now')}
             WHERE created_at IS NULL
             """,
-            (now,),
+            {"now": now},
         )
         await db.execute(
             f"""
-            UPDATE accounts SET created_at = {db.timestamp_placeholder}
+            UPDATE accounts SET created_at = {db.timestamp_placeholder('now')}
             WHERE created_at IS NULL
             """,
-            (now,),
+            {"now": now},
         )
     except OperationalError as exc:
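
The pattern above recurs throughout the migrations: datetimes are bound to a named key, and the SQL side is rendered by `timestamp_placeholder`, which expands per backend. A sketch of the expansion, mirroring `compat_timestamp_placeholder` from this diff (for illustration only):

    def timestamp_placeholder(db_type: str, key: str) -> str:
        if db_type == "POSTGRES":
            return f"to_timestamp(:{key})"
        if db_type == "COCKROACH":
            return f"cast(:{key} AS timestamp)"
        return f":{key}"  # SQLite stores timestamps as plain numbers

    assert timestamp_placeholder("POSTGRES", "expiry") == "to_timestamp(:expiry)"
    # so m007's UPDATE renders on Postgres as:
    #   UPDATE apipayments SET expiry = to_timestamp(:expiry)
    #   WHERE checking_id = :checking_id AND amount > 0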

View File

@@ -7,7 +7,6 @@ import json
 import time
 from dataclasses import dataclass
 from enum import Enum
-from sqlite3 import Row
 from typing import Callable, Optional

 from ecdsa import SECP256k1, SigningKey
@@ -240,7 +239,7 @@ class Payment(FromRowModel):
         return self.status == PaymentState.FAILED.value

     @classmethod
-    def from_row(cls, row: Row):
+    def from_row(cls, row: dict):
         return cls(
             checking_id=row["checking_id"],
             payment_hash=row["hash"] or "0" * 64,
@@ -347,7 +346,7 @@ class TinyURL(BaseModel):
     time: float

     @classmethod
-    def from_row(cls, row: Row):
+    def from_row(cls, row: dict):
         return cls(**dict(row))
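
`from_row` loses the `sqlite3.Row` annotation because rows now come from SQLAlchemy's `Result.mappings()`, which yields dict-like mappings on every backend. A sketch of the contract (fields other than `time` are illustrative):

    from pydantic import BaseModel

    class TinyURL(BaseModel):
        id: str
        url: str
        time: float

        @classmethod
        def from_row(cls, row: dict) -> "TinyURL":
            # works for plain dicts and for SQLAlchemy RowMapping alike
            return cls(**dict(row))

    tiny = TinyURL.from_row({"id": "x1", "url": "https://example.com", "time": 0.0})
    assert tiny.url == "https://example.com"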

View File

@@ -7,14 +7,13 @@ import re
 import time
 from contextlib import asynccontextmanager
 from enum import Enum
-from sqlite3 import Row
 from typing import Any, Generic, Literal, Optional, TypeVar

 from loguru import logger
 from pydantic import BaseModel, ValidationError, root_validator
-from sqlalchemy import create_engine
-from sqlalchemy_aio.base import AsyncConnection
-from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
+from sqlalchemy import event
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+from sqlalchemy.sql import text

 from lnbits.settings import settings
@@ -24,31 +23,15 @@ SQLITE = "SQLITE"
 if settings.lnbits_database_url:
     database_uri = settings.lnbits_database_url

     if database_uri.startswith("cockroachdb://"):
         DB_TYPE = COCKROACH
     else:
+        if not database_uri.startswith("postgres://"):
+            raise ValueError(
+                "Please use the 'postgres://...' " "format for the database URL."
+            )
         DB_TYPE = POSTGRES
-
-        from psycopg2.extensions import DECIMAL, new_type, register_type
-
-        def _parse_timestamp(value, _):
-            if value is None:
-                return None
-            f = "%Y-%m-%d %H:%M:%S.%f"
-            if "." not in value:
-                f = "%Y-%m-%d %H:%M:%S"
-            return time.mktime(datetime.datetime.strptime(value, f).timetuple())
-
-        register_type(
-            new_type(
-                DECIMAL.values,
-                "DEC2FLOAT",
-                lambda value, curs: float(value) if value is not None else None,
-            )
-        )
-        register_type(new_type((1184, 1114), "TIMESTAMP2INT", _parse_timestamp))
 else:
     if not os.path.isdir(settings.lnbits_data_folder):
         os.mkdir(settings.lnbits_data_folder)
@@ -56,21 +39,21 @@ else:
     DB_TYPE = SQLITE


-def compat_timestamp_placeholder():
+def compat_timestamp_placeholder(key: str):
     if DB_TYPE == POSTGRES:
-        return "to_timestamp(?)"
+        return f"to_timestamp(:{key})"
     elif DB_TYPE == COCKROACH:
-        return "cast(? AS timestamp)"
+        return f"cast(:{key} AS timestamp)"
     else:
-        return "?"
+        return f":{key}"


 def get_placeholder(model: Any, field: str) -> str:
     type_ = model.__fields__[field].type_
     if type_ == datetime.datetime:
-        return compat_timestamp_placeholder()
+        return compat_timestamp_placeholder(field)
     else:
-        return "?"
+        return f":{field}"


 class Compat:
@@ -127,15 +110,13 @@ class Compat:
             return "BIGINT"
         return "INT"

-    @property
-    def timestamp_placeholder(self) -> str:
-        return compat_timestamp_placeholder()
+    def timestamp_placeholder(self, key: str) -> str:
+        return compat_timestamp_placeholder(key)


 class Connection(Compat):
-    def __init__(self, conn: AsyncConnection, txn, typ, name, schema):
+    def __init__(self, conn: AsyncConnection, typ, name, schema):
         self.conn = conn
-        self.txn = txn
         self.type = typ
         self.name = name
         self.schema = schema
@@ -146,45 +127,42 @@ class Connection(Compat):
             query = query.replace("?", "%s")
         return query

-    def rewrite_values(self, values):
+    def rewrite_values(self, values: dict) -> dict:
         # strip html
         clean_regex = re.compile("<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});")
+        clean_values: dict = {}

-        # tuple to list and back to tuple
-        raw_values = [values] if isinstance(values, str) else list(values)
-        values = []
-        for raw_value in raw_values:
+        for key, raw_value in values.items():
             if isinstance(raw_value, str):
-                values.append(re.sub(clean_regex, "", raw_value))
+                clean_values[key] = re.sub(clean_regex, "", raw_value)
             elif isinstance(raw_value, datetime.datetime):
                 ts = raw_value.timestamp()
                 if self.type == SQLITE:
-                    values.append(int(ts))
+                    clean_values[key] = int(ts)
                 else:
-                    values.append(ts)
+                    clean_values[key] = ts
             else:
-                values.append(raw_value)
-        return tuple(values)
+                clean_values[key] = raw_value
+        return clean_values

-    async def fetchall(self, query: str, values: tuple = ()) -> list:
-        result = await self.conn.execute(
-            self.rewrite_query(query), self.rewrite_values(values)
-        )
-        return await result.fetchall()
+    async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]:
+        params = self.rewrite_values(values) if values else {}
+        result = await self.conn.execute(text(self.rewrite_query(query)), params)
+        row = result.mappings().all()
+        result.close()
+        return row

-    async def fetchone(self, query: str, values: tuple = ()):
-        result = await self.conn.execute(
-            self.rewrite_query(query), self.rewrite_values(values)
-        )
-        row = await result.fetchone()
-        await result.close()
+    async def fetchone(self, query: str, values: Optional[dict] = None) -> dict:
+        params = self.rewrite_values(values) if values else {}
+        result = await self.conn.execute(text(self.rewrite_query(query)), params)
+        row = result.mappings().first()
+        result.close()
         return row

     async def fetch_page(
         self,
         query: str,
         where: Optional[list[str]] = None,
-        values: Optional[list[str]] = None,
+        values: Optional[dict] = None,
         filters: Optional[Filters] = None,
         model: Optional[type[TRowModel]] = None,
         group_by: Optional[list[str]] = None,
@@ -211,14 +189,14 @@ class Connection(Compat):
                 {filters.order_by()}
                 {filters.pagination()}
             """,
-            parsed_values,
+            self.rewrite_values(parsed_values),
         )
         if rows:
             # no need for extra query if no pagination is specified
             if filters.offset or filters.limit:
-                count = await self.fetchone(
+                result = await self.fetchone(
                     f"""
-                    SELECT COUNT(*) FROM (
+                    SELECT COUNT(*) as count FROM (
                         {query}
                         {clause}
                         {group_by_string}
@@ -226,21 +204,22 @@ class Connection(Compat):
                     """,
                     parsed_values,
                 )
-                count = int(count[0])
+                count = int(result.get("count", 0))
             else:
                 count = len(rows)
         else:
             count = 0

         return Page(
-            data=[model.from_row(row) for row in rows] if model else rows,
+            data=[model.from_row(row) for row in rows] if model else [],
             total=count,
         )

-    async def execute(self, query: str, values: tuple = ()):
-        return await self.conn.execute(
-            self.rewrite_query(query), self.rewrite_values(values)
-        )
+    async def execute(self, query: str, values: Optional[dict] = None):
+        params = self.rewrite_values(values) if values else {}
+        result = await self.conn.execute(text(self.rewrite_query(query)), params)
+        await self.conn.commit()
+        return result


 class Database(Compat):
@@ -253,18 +232,44 @@ class Database(Compat):
             self.path = os.path.join(
                 settings.lnbits_data_folder, f"{self.name}.sqlite3"
             )
-            database_uri = f"sqlite:///{self.path}"
+            database_uri = f"sqlite+aiosqlite:///{self.path}"
         else:
-            database_uri = settings.lnbits_database_url
+            database_uri = settings.lnbits_database_url.replace(
+                "postgres://", "postgresql+asyncpg://"
+            )

         if self.name.startswith("ext_"):
             self.schema = self.name[4:]
         else:
             self.schema = None

-        self.engine = create_engine(
-            database_uri, strategy=ASYNCIO_STRATEGY, echo=settings.debug_database
+        self.engine: AsyncEngine = create_async_engine(
+            database_uri, echo=settings.debug_database
         )

+        if self.type in {POSTGRES, COCKROACH}:
+
+            @event.listens_for(self.engine.sync_engine, "connect")
+            def register_custom_types(dbapi_connection, *_):
+                def _parse_timestamp(value):
+                    if value is None:
+                        return None
+                    f = "%Y-%m-%d %H:%M:%S.%f"
+                    if "." not in value:
+                        f = "%Y-%m-%d %H:%M:%S"
+                    return int(
+                        time.mktime(datetime.datetime.strptime(value, f).timetuple())
+                    )
+
+                dbapi_connection.run_async(
+                    lambda connection: connection.set_type_codec(
+                        "TIMESTAMP",
+                        encoder=datetime.datetime,
+                        decoder=_parse_timestamp,
+                        schema="pg_catalog",
+                    )
+                )
+
         self.lock = asyncio.Lock()
         logger.trace(f"database {self.type} added for {self.name}")
@@ -273,41 +278,37 @@ class Database(Compat):
     async def connect(self):
         await self.lock.acquire()
         try:
-            async with self.engine.connect() as conn:  # type: ignore
-                async with conn.begin() as txn:
-                    wconn = Connection(conn, txn, self.type, self.name, self.schema)
-
-                    if self.schema:
-                        if self.type in {POSTGRES, COCKROACH}:
-                            await wconn.execute(
-                                f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
-                            )
-                        elif self.type == SQLITE:
-                            await wconn.execute(
-                                f"ATTACH '{self.path}' AS {self.schema}"
-                            )
-
-                    yield wconn
+            async with self.engine.connect() as conn:
+                if not conn:
+                    raise Exception("Could not connect to the database")
+
+                wconn = Connection(conn, self.type, self.name, self.schema)
+
+                if self.schema:
+                    if self.type in {POSTGRES, COCKROACH}:
+                        await wconn.execute(
+                            f"CREATE SCHEMA IF NOT EXISTS {self.schema}"
+                        )
+                    elif self.type == SQLITE:
+                        await wconn.execute(f"ATTACH '{self.path}' AS {self.schema}")
+
+                yield wconn
         finally:
             self.lock.release()

-    async def fetchall(self, query: str, values: tuple = ()) -> list:
+    async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]:
         async with self.connect() as conn:
-            result = await conn.execute(query, values)
-            return await result.fetchall()
+            return await conn.fetchall(query, values)

-    async def fetchone(self, query: str, values: tuple = ()):
+    async def fetchone(self, query: str, values: Optional[dict] = None) -> dict:
         async with self.connect() as conn:
-            result = await conn.execute(query, values)
-            row = await result.fetchone()
-            await result.close()
-            return row
+            return await conn.fetchone(query, values)

     async def fetch_page(
         self,
         query: str,
         where: Optional[list[str]] = None,
-        values: Optional[list[str]] = None,
+        values: Optional[dict] = None,
         filters: Optional[Filters] = None,
         model: Optional[type[TRowModel]] = None,
         group_by: Optional[list[str]] = None,
@@ -315,7 +316,7 @@ class Database(Compat):
         async with self.connect() as conn:
             return await conn.fetch_page(query, where, values, filters, model, group_by)

-    async def execute(self, query: str, values: tuple = ()):
+    async def execute(self, query: str, values: Optional[dict] = None):
         async with self.connect() as conn:
             return await conn.execute(query, values)
@@ -373,8 +374,8 @@ class Operator(Enum):
 class FromRowModel(BaseModel):
     @classmethod
-    def from_row(cls, row: Row):
-        return cls(**dict(row))
+    def from_row(cls, row: dict):
+        return cls(**row)


 class FilterModel(BaseModel):
@@ -396,12 +397,13 @@ class Page(BaseModel, Generic[T]):
 class Filter(BaseModel, Generic[TFilterModel]):
     field: str
     op: Operator = Operator.EQ
-    values: list[Any]
     model: Optional[type[TFilterModel]]
+    values: Optional[dict] = None

     @classmethod
-    def parse_query(cls, key: str, raw_values: list[Any], model: type[TFilterModel]):
+    def parse_query(
+        cls, key: str, raw_values: list[Any], model: type[TFilterModel], i: int = 0
+    ):
         # Key format:
         # key[operator]
         # e.g. name[eq]
@@ -417,12 +419,12 @@ class Filter(BaseModel, Generic[TFilterModel]):
         if field in model.__fields__:
             compare_field = model.__fields__[field]
-            values = []
+            values: dict = {}
             for raw_value in raw_values:
                 validated, errors = compare_field.validate(raw_value, {}, loc="none")
                 if errors:
                     raise ValidationError(errors=[errors], model=model)
-                values.append(validated)
+                values[f"{field}__{i}"] = validated
         else:
             raise ValueError("Unknown filter field")
@@ -430,13 +432,17 @@ class Filter(BaseModel, Generic[TFilterModel]):
     @property
     def statement(self):
-        assert self.model, "Model is required for statement generation"
-        placeholder = get_placeholder(self.model, self.field)
-        if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
-            placeholders = ", ".join([placeholder] * len(self.values))
-            stmt = [f"{self.field} {self.op.as_sql} ({placeholders})"]
-        else:
-            stmt = [f"{self.field} {self.op.as_sql} {placeholder}"] * len(self.values)
+        stmt = []
+        for key in self.values.keys() if self.values else []:
+            clean_key = key.split("__")[0]
+            if (
+                self.model
+                and self.model.__fields__[clean_key].type_ == datetime.datetime
+            ):
+                placeholder = compat_timestamp_placeholder(key)
+            else:
+                placeholder = f":{key}"
+            stmt.append(f"{clean_key} {self.op.as_sql} {placeholder}")
         return " OR ".join(stmt)
@@ -487,14 +493,11 @@ class Filters(BaseModel, Generic[TFilterModel]):
             for page_filter in self.filters:
                 where_stmts.append(page_filter.statement)
         if self.search and self.model:
+            fields = self.model.__search_fields__
             if DB_TYPE == POSTGRES:
-                where_stmts.append(
-                    f"lower(concat({', '.join(self.model.__search_fields__)})) LIKE ?"
-                )
+                where_stmts.append(f"lower(concat({', '.join(fields)})) LIKE :search")
             elif DB_TYPE == SQLITE:
-                where_stmts.append(
-                    f"lower({'||'.join(self.model.__search_fields__)}) LIKE ?"
-                )
+                where_stmts.append(f"lower({'||'.join(fields)}) LIKE :search")
         if where_stmts:
             return "WHERE " + " AND ".join(where_stmts)
         return ""
@@ -504,12 +507,14 @@ class Filters(BaseModel, Generic[TFilterModel]):
             return f"ORDER BY {self.sortby} {self.direction or 'asc'}"
         return ""

-    def values(self, values: Optional[list[str]] = None) -> tuple:
+    def values(self, values: Optional[dict] = None) -> dict:
         if not values:
-            values = []
+            values = {}
         if self.filters:
             for page_filter in self.filters:
-                values.extend(page_filter.values)
+                if page_filter.values:
+                    for key, value in page_filter.values.items():
+                        values[key] = value
         if self.search and self.model:
-            values.append(f"%{self.search}%")
-        return tuple(values)
+            values["search"] = f"%{self.search}%"
+        return values
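
End to end, callers now pass a dict and read mapping rows by name; note the `COUNT(*) as count` alias above, needed because positional `row[0]` access is gone. A sketch against the new API (assumes an LNbits `Database` instance; the table and column are the ones from the migrations above):

    async def count_positive_payments(db) -> int:
        # values travel as a dict keyed by the :named placeholders in the query
        row = await db.fetchone(
            "SELECT COUNT(*) AS count FROM apipayments WHERE amount > :amount",
            {"amount": 0},
        )
        return int(row.get("count", 0))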

View File

@@ -204,9 +204,9 @@ def parse_filters(model: Type[TFilterModel]):
     ):
         params = request.query_params
         filters = []
-        for key in params.keys():
+        for i, key in enumerate(params.keys()):
             try:
-                filters.append(Filter.parse_query(key, params.getlist(key), model))
+                filters.append(Filter.parse_query(key, params.getlist(key), model, i))
             except ValueError:
                 continue
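
Enumerating the query parameters gives each filter a distinct index, so repeated fields bind distinct placeholder names instead of colliding:

    # ?time[ge]=10&time[le]=99 enumerates to i=0 and i=1, so the two Filters
    # bind :time__0 and :time__1 rather than both claiming a single :time key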

View File

@@ -187,12 +187,14 @@ def insert_query(table_name: str, model: BaseModel) -> str:
     return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"


-def update_query(table_name: str, model: BaseModel, where: str = "WHERE id = ?") -> str:
+def update_query(
+    table_name: str, model: BaseModel, where: str = "WHERE id = :id"
+) -> str:
     """
     Generate an update query with placeholders for a given table and model

     :param table_name: Name of the table
     :param model: Pydantic model
-    :param where: Where string, default to `WHERE id = ?`
+    :param where: Where string, default to `WHERE id = :id`
     """
     fields = []
     for field in model.dict().keys():
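
With named placeholders, the generated queries pair directly with `model.dict()` as the values dict. A sketch (the model mirrors the one in the tests further below):

    from pydantic import BaseModel

    class DbTestModel(BaseModel):
        id: int
        name: str
        value: str

    m = DbTestModel(id=1, name="test", value="yes")
    # insert_query("test_helpers_query", m) returns
    #   "INSERT INTO test_helpers_query (id, name, value) VALUES (:id, :name, :value)"
    # and executes as:
    #   await db.execute(insert_query("test_helpers_query", m), m.dict())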

View File

@@ -74,10 +74,10 @@ def configure_logger() -> None:
     logging.getLogger("uvicorn.error").propagate = False

     logging.getLogger("sqlalchemy").handlers = [InterceptHandler()]
-    logging.getLogger("sqlalchemy.engine.base").handlers = [InterceptHandler()]
-    logging.getLogger("sqlalchemy.engine.base").propagate = False
-    logging.getLogger("sqlalchemy.engine.base.Engine").handlers = [InterceptHandler()]
-    logging.getLogger("sqlalchemy.engine.base.Engine").propagate = False
+    logging.getLogger("sqlalchemy.engine").handlers = [InterceptHandler()]
+    logging.getLogger("sqlalchemy.engine").propagate = False
+    logging.getLogger("sqlalchemy.engine.Engine").handlers = [InterceptHandler()]
+    logging.getLogger("sqlalchemy.engine.Engine").propagate = False


 class Formatter:

poetry.lock (generated): 997 changed lines

File diff suppressed because it is too large

View File

@@ -16,25 +16,25 @@ python = "^3.12 | ^3.11 | ^3.10 | ^3.9"
 bech32 = "1.2.0"
 click = "8.1.7"
 ecdsa = "0.19.0"
-fastapi = "0.112.0"
+fastapi = "0.113.0"
 httpx = "0.27.0"
 jinja2 = "3.1.4"
 lnurl = "0.5.3"
-psycopg2-binary = "2.9.9"
-pydantic = "1.10.17"
+pydantic = "1.10.18"
 pyqrcode = "1.2.1"
 shortuuid = "1.0.13"
-sqlalchemy = "1.3.24"
-sqlalchemy-aio = "0.17.0"
 sse-starlette = "1.8.2"
 typing-extensions = "4.12.2"
-uvicorn = "0.30.5"
+uvicorn = "0.30.6"
+sqlalchemy = "1.4.54"
+aiosqlite = "0.20.0"
+asyncpg = "0.29.0"
 uvloop = "0.19.0"
 websockets = "11.0.3"
 loguru = "0.7.2"
-grpcio = "1.65.5"
-protobuf = "5.27.3"
-pyln-client = "24.5"
+grpcio = "1.66.1"
+protobuf = "5.28.0"
+pyln-client = "24.8.1"
 pywebpush = "1.14.1"
 slowapi = "0.1.9"
 websocket-client = "1.8.0"
@@ -70,11 +70,11 @@ black = "^24.8.0"
 pytest-asyncio = "^0.21.2"
 pytest = "^8.3.2"
 pytest-cov = "^4.1.0"
-mypy = "^1.11.1"
+mypy = "^1.11.2"
 types-protobuf = "^5.27.0.20240626"
 pre-commit = "^3.8.0"
 openapi-spec-validator = "^0.7.1"
-ruff = "^0.5.7"
+ruff = "^0.6.4"
 types-passlib = "^1.7.7.20240327"
 openai = "^1.39.0"
 json5 = "^0.9.25"
@@ -84,7 +84,7 @@ pytest-httpserver = "^1.1.0"
 pytest-mock = "^3.14.0"
 types-mock = "^5.1.0.20240425"
 mock = "^5.1.0"
-grpcio-tools = "^1.65.5"
+grpcio-tools = "^1.66.1"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -126,7 +126,6 @@ module = [
     "secp256k1.*",
     "uvicorn.*",
     "sqlalchemy.*",
-    "sqlalchemy_aio.*",
     "websocket.*",
     "websockets.*",
     "pyqrcode.*",
@@ -136,7 +135,6 @@ module = [
     "bolt11.*",
     "bitstring.*",
     "ecdsa.*",
-    "psycopg2.*",
     "pyngrok.*",
     "pyln.client.*",
     "py_vapid.*",

View File

@@ -367,11 +367,11 @@ async def test_get_payments_history(client, adminkey_headers_from, fake_payments
     assert response.status_code == 200
     data = response.json()
     assert len(data) == 1
-    assert data[0]["spending"] == sum(
-        payment.amount * 1000 for payment in fake_data if payment.out
-    )
     assert data[0]["income"] == sum(
-        payment.amount * 1000 for payment in fake_data if not payment.out
+        [int(payment.amount * 1000) for payment in fake_data if not payment.out]
     )
+    assert data[0]["spending"] == sum(
+        [int(payment.amount * 1000) for payment in fake_data if payment.out]
+    )

     response = await client.get(

View File

@@ -25,7 +25,6 @@ from lnbits.core.views.payment_api import api_payments_create_invoice
 from lnbits.db import DB_TYPE, SQLITE, Database
 from lnbits.settings import settings
 from tests.helpers import (
-    clean_database,
     get_random_invoice_data,
 )
@@ -47,7 +46,6 @@ def event_loop():
 # use session scope to run once before and once after all tests
 @pytest_asyncio.fixture(scope="session")
 async def app():
-    clean_database(settings)
     app = create_app()
     async with LifespanManager(app) as manager:
         settings.first_install = False
@@ -199,9 +197,9 @@ async def fake_payments(client, adminkey_headers_from):
             "/api/v1/payments", headers=adminkey_headers_from, json=invoice.dict()
         )
         assert response.is_success
-        await update_payment_status(
-            response.json()["checking_id"], status=PaymentState.SUCCESS
-        )
+        data = response.json()
+        assert data["checking_id"]
+        await update_payment_status(data["checking_id"], status=PaymentState.SUCCESS)

     params = {"time[ge]": ts, "time[le]": time()}
     return fake_data, params

View File

@@ -2,11 +2,7 @@ import random
 import string
 from typing import Optional

-from psycopg2 import connect
-from psycopg2.errors import InvalidCatalogName
-
-from lnbits import core
-from lnbits.db import DB_TYPE, POSTGRES, FromRowModel
+from lnbits.db import FromRowModel
 from lnbits.wallets import get_funding_source, set_funding_source
@@ -35,21 +31,3 @@ set_funding_source()
 funding_source = get_funding_source()
 is_fake: bool = funding_source.__class__.__name__ == "FakeWallet"
 is_regtest: bool = not is_fake
-
-
-def clean_database(settings):
-    if DB_TYPE == POSTGRES:
-        conn = connect(settings.lnbits_database_url)
-        conn.autocommit = True
-        with conn.cursor() as cur:
-            try:
-                cur.execute("DROP DATABASE lnbits_test")
-            except InvalidCatalogName:
-                pass
-            cur.execute("CREATE DATABASE lnbits_test")
-        core.db.__init__("database")
-        conn.close()
-    else:
-        # TODO: do this once mock data is removed from test data folder
-        # os.remove(settings.lnbits_data_folder + "/database.sqlite3")
-        pass

View File

@@ -14,8 +14,8 @@ from lnbits.db import POSTGRES
 @pytest.mark.asyncio
 async def test_date_conversion(db):
     if db.type == POSTGRES:
-        row = await db.fetchone("SELECT now()::date")
-        assert row and isinstance(row[0], date)
+        row = await db.fetchone("SELECT now()::date as now")
+        assert row and isinstance(row.get("now"), date)


 # make test to create wallet and delete wallet

View File

@@ -12,10 +12,17 @@ test = DbTestModel(id=1, name="test", value="yes")
 @pytest.mark.asyncio
 async def test_helpers_insert_query():
     q = insert_query("test_helpers_query", test)
-    assert q == "INSERT INTO test_helpers_query (id, name, value) VALUES (?, ?, ?)"
+    assert (
+        q == "INSERT INTO test_helpers_query (id, name, value) "
+        "VALUES (:id, :name, :value)"
+    )


 @pytest.mark.asyncio
 async def test_helpers_update_query():
     q = update_query("test_helpers_query", test)
-    assert q == "UPDATE test_helpers_query SET id = ?, name = ?, value = ? WHERE id = ?"
+    assert (
+        q == "UPDATE test_helpers_query "
+        "SET id = :id, name = :name, value = :value "
+        "WHERE id = :id"
+    )

View File

@@ -1,5 +1,5 @@
 # Python script to migrate an LNbits SQLite DB to Postgres
-# All credits to @Fritz446 for the awesome work
+# credits to @Fritz446 for the awesome work

 # pip install psycopg2 OR psycopg2-binary
@@ -9,10 +9,14 @@ import sqlite3
 import sys
 from typing import List, Optional

-import psycopg2
-
 from lnbits.settings import settings

+try:
+    import psycopg2  # type: ignore
+except ImportError:
+    print("Please install psycopg2")
+    sys.exit(1)
+
 sqfolder = settings.lnbits_data_folder
 db_url = settings.lnbits_database_url
@@ -55,8 +59,8 @@ def check_db_versions(sqdb):
             version = dbpost[key]
             if value != version:
                 raise Exception(
-                    f"sqlite database version ({value}) of {key} doesn't match postgres"
-                    f" database version {version}"
+                    f"sqlite database version ({value}) of {key} doesn't match "
+                    f"postgres database version {version}"
                 )

 connection = postgres.connection