mirror of
https://github.com/lnbits/lnbits.git
synced 2025-11-24 12:57:43 +01:00
[perf] Performance bust paginated search (#3543)
This commit is contained in:
@@ -90,11 +90,12 @@ async def get_user_assets(
|
|||||||
filters = filters or Filters()
|
filters = filters or Filters()
|
||||||
filters.sortby = filters.sortby or "created_at"
|
filters.sortby = filters.sortby or "created_at"
|
||||||
return await (conn or db).fetch_page(
|
return await (conn or db).fetch_page(
|
||||||
query="SELECT * from assets",
|
query="SELECT * FROM assets",
|
||||||
where=["user_id = :user_id"],
|
where=["user_id = :user_id"],
|
||||||
values={"user_id": user_id},
|
values={"user_id": user_id},
|
||||||
filters=filters,
|
filters=filters,
|
||||||
model=AssetInfo,
|
model=AssetInfo,
|
||||||
|
table_name="assets",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,11 +16,12 @@ async def get_audit_entries(
|
|||||||
conn: Connection | None = None,
|
conn: Connection | None = None,
|
||||||
) -> Page[AuditEntry]:
|
) -> Page[AuditEntry]:
|
||||||
return await (conn or db).fetch_page(
|
return await (conn or db).fetch_page(
|
||||||
"SELECT * from audit",
|
"SELECT * FROM audit",
|
||||||
[],
|
[],
|
||||||
{},
|
{},
|
||||||
filters=filters,
|
filters=filters,
|
||||||
model=AuditEntry,
|
model=AuditEntry,
|
||||||
|
table_name="audit",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -181,6 +181,7 @@ async def get_payments_paginated( # noqa: C901
|
|||||||
values,
|
values,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
model=Payment,
|
model=Payment,
|
||||||
|
table_name="apipayments",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ async def get_accounts(
|
|||||||
filters=filters,
|
filters=filters,
|
||||||
model=AccountOverview,
|
model=AccountOverview,
|
||||||
group_by=["accounts.id"],
|
group_by=["accounts.id"],
|
||||||
|
table_name="accounts",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ async def get_wallets_paginated(
|
|||||||
values={"user": user_id, "deleted": deleted},
|
values={"user": user_id, "deleted": deleted},
|
||||||
filters=filters,
|
filters=filters,
|
||||||
model=Wallet,
|
model=Wallet,
|
||||||
|
table_name="wallets",
|
||||||
)
|
)
|
||||||
|
|
||||||
wallets.data = await get_source_wallets(wallets.data, conn)
|
wallets.data = await get_source_wallets(wallets.data, conn)
|
||||||
|
|||||||
47
lnbits/db.py
47
lnbits/db.py
@@ -222,7 +222,22 @@ class Connection(Compat):
|
|||||||
filters: Filters | None = None,
|
filters: Filters | None = None,
|
||||||
model: type[TModel] | None = None,
|
model: type[TModel] | None = None,
|
||||||
group_by: list[str] | None = None,
|
group_by: list[str] | None = None,
|
||||||
|
table_name: str | None = None,
|
||||||
) -> Page[TModel]:
|
) -> Page[TModel]:
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
query: The main SQL query string to execute for data retrieval.
|
||||||
|
where: list of additional WHERE clause conditions to filter results.
|
||||||
|
values: dictionary of parameter values to be used in the SQL query.
|
||||||
|
filters: object for advanced filtering, sorting, and pagination logic.
|
||||||
|
model: pydantic model type to map query results into model instances.
|
||||||
|
group_by: list of column names to group results by in the SQL query.
|
||||||
|
table_name: if provided some optimisations can be applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if table_name and not _valid_sql_name(table_name):
|
||||||
|
raise ValueError(f"Invalid table name: '{table_name}'.")
|
||||||
|
|
||||||
if not filters:
|
if not filters:
|
||||||
filters = Filters()
|
filters = Filters()
|
||||||
clause = filters.where(where)
|
clause = filters.where(where)
|
||||||
@@ -231,9 +246,7 @@ class Connection(Compat):
|
|||||||
group_by_string = ""
|
group_by_string = ""
|
||||||
if group_by:
|
if group_by:
|
||||||
for field in group_by:
|
for field in group_by:
|
||||||
if not re.fullmatch(
|
if not _valid_sql_name(field):
|
||||||
r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", field
|
|
||||||
):
|
|
||||||
raise ValueError("Value for GROUP BY is invalid")
|
raise ValueError("Value for GROUP BY is invalid")
|
||||||
group_by_string = f"GROUP BY {', '.join(group_by)}"
|
group_by_string = f"GROUP BY {', '.join(group_by)}"
|
||||||
|
|
||||||
@@ -251,16 +264,17 @@ class Connection(Compat):
|
|||||||
if rows:
|
if rows:
|
||||||
# no need for extra query if no pagination is specified
|
# no need for extra query if no pagination is specified
|
||||||
if filters.offset or filters.limit:
|
if filters.offset or filters.limit:
|
||||||
result = await self.execute(
|
if table_name:
|
||||||
f"""
|
count_query = f"SELECT COUNT(*) as count FROM {table_name} {clause}" # noqa: S608
|
||||||
SELECT COUNT(*) as count FROM (
|
else:
|
||||||
|
count_query = f"""SELECT COUNT(*) as count
|
||||||
|
FROM (
|
||||||
{query}
|
{query}
|
||||||
{clause}
|
{clause}
|
||||||
{group_by_string}
|
{group_by_string}
|
||||||
) as count
|
) as count""" # noqa: S608
|
||||||
""", # noqa: S608
|
|
||||||
parsed_values,
|
result = await self.execute(count_query, parsed_values)
|
||||||
)
|
|
||||||
row = result.mappings().first()
|
row = result.mappings().first()
|
||||||
result.close()
|
result.close()
|
||||||
count = int(row.get("count", 0))
|
count = int(row.get("count", 0))
|
||||||
@@ -393,9 +407,12 @@ class Database(Compat):
|
|||||||
filters: Filters | None = None,
|
filters: Filters | None = None,
|
||||||
model: type[TModel] | None = None,
|
model: type[TModel] | None = None,
|
||||||
group_by: list[str] | None = None,
|
group_by: list[str] | None = None,
|
||||||
|
table_name: str | None = None,
|
||||||
) -> Page[TModel]:
|
) -> Page[TModel]:
|
||||||
async with self.connect() as conn:
|
async with self.connect() as conn:
|
||||||
return await conn.fetch_page(query, where, values, filters, model, group_by)
|
return await conn.fetch_page(
|
||||||
|
query, where, values, filters, model, group_by, table_name
|
||||||
|
)
|
||||||
|
|
||||||
async def execute(self, query: str, values: dict | None = None):
|
async def execute(self, query: str, values: dict | None = None):
|
||||||
async with self.connect() as conn:
|
async with self.connect() as conn:
|
||||||
@@ -750,3 +767,11 @@ def _safe_load_json(value: str) -> dict:
|
|||||||
# DB is corrupted if it gets here
|
# DB is corrupted if it gets here
|
||||||
logger.error(f"Failed to decode JSON: '{value}'")
|
logger.error(f"Failed to decode JSON: '{value}'")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_sql_name(name: str) -> bool:
|
||||||
|
"""Check if a SQL name is valid (alphanumeric and underscores only)"""
|
||||||
|
return (
|
||||||
|
re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", name)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user