[perf] Performance bust paginated search (#3543)

This commit is contained in:
Vlad Stan
2025-11-21 18:29:23 +02:00
committed by GitHub
parent b1a7692ce4
commit 44985cb0d1
6 changed files with 43 additions and 13 deletions

View File

@@ -90,11 +90,12 @@ async def get_user_assets(
filters = filters or Filters() filters = filters or Filters()
filters.sortby = filters.sortby or "created_at" filters.sortby = filters.sortby or "created_at"
return await (conn or db).fetch_page( return await (conn or db).fetch_page(
query="SELECT * from assets", query="SELECT * FROM assets",
where=["user_id = :user_id"], where=["user_id = :user_id"],
values={"user_id": user_id}, values={"user_id": user_id},
filters=filters, filters=filters,
model=AssetInfo, model=AssetInfo,
table_name="assets",
) )

View File

@@ -16,11 +16,12 @@ async def get_audit_entries(
conn: Connection | None = None, conn: Connection | None = None,
) -> Page[AuditEntry]: ) -> Page[AuditEntry]:
return await (conn or db).fetch_page( return await (conn or db).fetch_page(
"SELECT * from audit", "SELECT * FROM audit",
[], [],
{}, {},
filters=filters, filters=filters,
model=AuditEntry, model=AuditEntry,
table_name="audit",
) )

View File

@@ -181,6 +181,7 @@ async def get_payments_paginated( # noqa: C901
values, values,
filters=filters, filters=filters,
model=Payment, model=Payment,
table_name="apipayments",
) )

View File

@@ -89,6 +89,7 @@ async def get_accounts(
filters=filters, filters=filters,
model=AccountOverview, model=AccountOverview,
group_by=["accounts.id"], group_by=["accounts.id"],
table_name="accounts",
) )

View File

@@ -187,6 +187,7 @@ async def get_wallets_paginated(
values={"user": user_id, "deleted": deleted}, values={"user": user_id, "deleted": deleted},
filters=filters, filters=filters,
model=Wallet, model=Wallet,
table_name="wallets",
) )
wallets.data = await get_source_wallets(wallets.data, conn) wallets.data = await get_source_wallets(wallets.data, conn)

View File

@@ -222,7 +222,22 @@ class Connection(Compat):
filters: Filters | None = None, filters: Filters | None = None,
model: type[TModel] | None = None, model: type[TModel] | None = None,
group_by: list[str] | None = None, group_by: list[str] | None = None,
table_name: str | None = None,
) -> Page[TModel]: ) -> Page[TModel]:
"""
Parameters:
query: The main SQL query string to execute for data retrieval.
where: list of additional WHERE clause conditions to filter results.
values: dictionary of parameter values to be used in the SQL query.
filters: object for advanced filtering, sorting, and pagination logic.
model: pydantic model type to map query results into model instances.
group_by: list of column names to group results by in the SQL query.
table_name: if provided some optimisations can be applied.
"""
if table_name and not _valid_sql_name(table_name):
raise ValueError(f"Invalid table name: '{table_name}'.")
if not filters: if not filters:
filters = Filters() filters = Filters()
clause = filters.where(where) clause = filters.where(where)
@@ -231,9 +246,7 @@ class Connection(Compat):
group_by_string = "" group_by_string = ""
if group_by: if group_by:
for field in group_by: for field in group_by:
if not re.fullmatch( if not _valid_sql_name(field):
r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", field
):
raise ValueError("Value for GROUP BY is invalid") raise ValueError("Value for GROUP BY is invalid")
group_by_string = f"GROUP BY {', '.join(group_by)}" group_by_string = f"GROUP BY {', '.join(group_by)}"
@@ -251,16 +264,17 @@ class Connection(Compat):
if rows: if rows:
# no need for extra query if no pagination is specified # no need for extra query if no pagination is specified
if filters.offset or filters.limit: if filters.offset or filters.limit:
result = await self.execute( if table_name:
f""" count_query = f"SELECT COUNT(*) as count FROM {table_name} {clause}" # noqa: S608
SELECT COUNT(*) as count FROM ( else:
count_query = f"""SELECT COUNT(*) as count
FROM (
{query} {query}
{clause} {clause}
{group_by_string} {group_by_string}
) as count ) as count""" # noqa: S608
""", # noqa: S608
parsed_values, result = await self.execute(count_query, parsed_values)
)
row = result.mappings().first() row = result.mappings().first()
result.close() result.close()
count = int(row.get("count", 0)) count = int(row.get("count", 0))
@@ -393,9 +407,12 @@ class Database(Compat):
filters: Filters | None = None, filters: Filters | None = None,
model: type[TModel] | None = None, model: type[TModel] | None = None,
group_by: list[str] | None = None, group_by: list[str] | None = None,
table_name: str | None = None,
) -> Page[TModel]: ) -> Page[TModel]:
async with self.connect() as conn: async with self.connect() as conn:
return await conn.fetch_page(query, where, values, filters, model, group_by) return await conn.fetch_page(
query, where, values, filters, model, group_by, table_name
)
async def execute(self, query: str, values: dict | None = None): async def execute(self, query: str, values: dict | None = None):
async with self.connect() as conn: async with self.connect() as conn:
@@ -750,3 +767,11 @@ def _safe_load_json(value: str) -> dict:
# DB is corrupted if it gets here # DB is corrupted if it gets here
logger.error(f"Failed to decode JSON: '{value}'") logger.error(f"Failed to decode JSON: '{value}'")
return {} return {}
def _valid_sql_name(name: str) -> bool:
"""Check if a SQL name is valid (alphanumeric and underscores only)"""
return (
re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", name)
is not None
)