From 44985cb0d109ede4deb992fd6ae39ffa06dc82ab Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Fri, 21 Nov 2025 18:29:23 +0200 Subject: [PATCH] [perf] Performance bust paginated search (#3543) --- lnbits/core/crud/assets.py | 3 ++- lnbits/core/crud/audit.py | 3 ++- lnbits/core/crud/payments.py | 1 + lnbits/core/crud/users.py | 1 + lnbits/core/crud/wallets.py | 1 + lnbits/db.py | 47 +++++++++++++++++++++++++++--------- 6 files changed, 43 insertions(+), 13 deletions(-) diff --git a/lnbits/core/crud/assets.py b/lnbits/core/crud/assets.py index 5bfa0d356..215ee3ac4 100644 --- a/lnbits/core/crud/assets.py +++ b/lnbits/core/crud/assets.py @@ -90,11 +90,12 @@ async def get_user_assets( filters = filters or Filters() filters.sortby = filters.sortby or "created_at" return await (conn or db).fetch_page( - query="SELECT * from assets", + query="SELECT * FROM assets", where=["user_id = :user_id"], values={"user_id": user_id}, filters=filters, model=AssetInfo, + table_name="assets", ) diff --git a/lnbits/core/crud/audit.py b/lnbits/core/crud/audit.py index a84f941c1..f65ada5c8 100644 --- a/lnbits/core/crud/audit.py +++ b/lnbits/core/crud/audit.py @@ -16,11 +16,12 @@ async def get_audit_entries( conn: Connection | None = None, ) -> Page[AuditEntry]: return await (conn or db).fetch_page( - "SELECT * from audit", + "SELECT * FROM audit", [], {}, filters=filters, model=AuditEntry, + table_name="audit", ) diff --git a/lnbits/core/crud/payments.py b/lnbits/core/crud/payments.py index 8204bd4c1..e11cd035e 100644 --- a/lnbits/core/crud/payments.py +++ b/lnbits/core/crud/payments.py @@ -181,6 +181,7 @@ async def get_payments_paginated( # noqa: C901 values, filters=filters, model=Payment, + table_name="apipayments", ) diff --git a/lnbits/core/crud/users.py b/lnbits/core/crud/users.py index 6f9142351..0620c2a11 100644 --- a/lnbits/core/crud/users.py +++ b/lnbits/core/crud/users.py @@ -89,6 +89,7 @@ async def get_accounts( filters=filters, model=AccountOverview, group_by=["accounts.id"], + table_name="accounts", ) diff --git a/lnbits/core/crud/wallets.py b/lnbits/core/crud/wallets.py index 67d63eaa7..b13d83c27 100644 --- a/lnbits/core/crud/wallets.py +++ b/lnbits/core/crud/wallets.py @@ -187,6 +187,7 @@ async def get_wallets_paginated( values={"user": user_id, "deleted": deleted}, filters=filters, model=Wallet, + table_name="wallets", ) wallets.data = await get_source_wallets(wallets.data, conn) diff --git a/lnbits/db.py b/lnbits/db.py index f59812460..0e1061981 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -222,7 +222,22 @@ class Connection(Compat): filters: Filters | None = None, model: type[TModel] | None = None, group_by: list[str] | None = None, + table_name: str | None = None, ) -> Page[TModel]: + """ + Parameters: + query: The main SQL query string to execute for data retrieval. + where: list of additional WHERE clause conditions to filter results. + values: dictionary of parameter values to be used in the SQL query. + filters: object for advanced filtering, sorting, and pagination logic. + model: pydantic model type to map query results into model instances. + group_by: list of column names to group results by in the SQL query. + table_name: if provided some optimisations can be applied. + """ + + if table_name and not _valid_sql_name(table_name): + raise ValueError(f"Invalid table name: '{table_name}'.") + if not filters: filters = Filters() clause = filters.where(where) @@ -231,9 +246,7 @@ class Connection(Compat): group_by_string = "" if group_by: for field in group_by: - if not re.fullmatch( - r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", field - ): + if not _valid_sql_name(field): raise ValueError("Value for GROUP BY is invalid") group_by_string = f"GROUP BY {', '.join(group_by)}" @@ -251,16 +264,17 @@ class Connection(Compat): if rows: # no need for extra query if no pagination is specified if filters.offset or filters.limit: - result = await self.execute( - f""" - SELECT COUNT(*) as count FROM ( + if table_name: + count_query = f"SELECT COUNT(*) as count FROM {table_name} {clause}" # noqa: S608 + else: + count_query = f"""SELECT COUNT(*) as count + FROM ( {query} {clause} {group_by_string} - ) as count - """, # noqa: S608 - parsed_values, - ) + ) as count""" # noqa: S608 + + result = await self.execute(count_query, parsed_values) row = result.mappings().first() result.close() count = int(row.get("count", 0)) @@ -393,9 +407,12 @@ class Database(Compat): filters: Filters | None = None, model: type[TModel] | None = None, group_by: list[str] | None = None, + table_name: str | None = None, ) -> Page[TModel]: async with self.connect() as conn: - return await conn.fetch_page(query, where, values, filters, model, group_by) + return await conn.fetch_page( + query, where, values, filters, model, group_by, table_name + ) async def execute(self, query: str, values: dict | None = None): async with self.connect() as conn: @@ -750,3 +767,11 @@ def _safe_load_json(value: str) -> dict: # DB is corrupted if it gets here logger.error(f"Failed to decode JSON: '{value}'") return {} + + +def _valid_sql_name(name: str) -> bool: + """Check if a SQL name is valid (alphanumeric and underscores only)""" + return ( + re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", name) + is not None + )