[fix] bandit sql warnings (#3242)

This commit is contained in:
Vlad Stan
2025-07-05 12:12:47 +03:00
committed by dni ⚡
parent e0749e186e
commit 76ecf113c3
8 changed files with 99 additions and 58 deletions

View File

@@ -30,10 +30,11 @@ async def delete_expired_audit_entries(
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
): ):
await (conn or db).execute( await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
DELETE from audit DELETE from audit
WHERE delete_at < {db.timestamp_now} WHERE delete_at < {db.timestamp_now}
""", """, # noqa: S608
) )
@@ -48,13 +49,16 @@ async def get_count_stats(
filters = Filters() filters = Filters()
clause = filters.where() clause = filters.where()
data = await (conn or db).fetchall( data = await (conn or db).fetchall(
# SQL injection vectors safety:
# - `field` is a static string, not user input
# - `clause` is generated from filters, which are validated and sanitized
query=f""" query=f"""
SELECT {field} as field, count({field}) as total SELECT {field} as field, count({field}) as total
FROM audit FROM audit
{clause} {clause}
GROUP BY {field} GROUP BY {field}
ORDER BY {field} ORDER BY {field}
""", """, # noqa: S608
values=filters.values(), values=filters.values(),
model=AuditCountStat, model=AuditCountStat,
) )
@@ -70,13 +74,15 @@ async def get_long_duration_stats(
filters = Filters() filters = Filters()
clause = filters.where() clause = filters.where()
long_duration_paths = await (conn or db).fetchall( long_duration_paths = await (conn or db).fetchall(
# This query is safe from SQL injection
# The `clause` is constructed from sanitized inputs
query=f""" query=f"""
SELECT path as field, max(duration) as total FROM audit SELECT path as field, max(duration) as total FROM audit
{clause} {clause}
GROUP BY path GROUP BY path
ORDER BY total DESC ORDER BY total DESC
LIMIT 5 LIMIT 5
""", """, # noqa: S608
values=filters.values(), values=filters.values(),
model=AuditCountStat, model=AuditCountStat,
) )

View File

@@ -78,10 +78,13 @@ async def get_installed_extensions(
active: Optional[bool] = None, active: Optional[bool] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> list[InstallableExtension]: ) -> list[InstallableExtension]:
where = "WHERE active = :active" if active is not None else "" query = "SELECT * FROM installed_extensions"
if active is not None:
query += " WHERE active = :active"
values = {"active": active} if active is not None else {} values = {"active": active} if active is not None else {}
all_extensions = await (conn or db).fetchall( all_extensions = await (conn or db).fetchall(
f"SELECT * FROM installed_extensions {where}", query,
values, values,
model=InstallableExtension, model=InstallableExtension,
) )

View File

@@ -50,11 +50,13 @@ async def get_standalone_payment(
clause = f"({clause}) AND wallet_id = :wallet_id" clause = f"({clause}) AND wallet_id = :wallet_id"
row = await (conn or db).fetchone( row = await (conn or db).fetchone(
# This query is safe from SQL injection
# The `clause` is constructed from sanitized inputs
f""" f"""
SELECT * FROM apipayments SELECT * FROM apipayments
WHERE {clause} WHERE {clause}
ORDER BY amount LIMIT 1 ORDER BY amount LIMIT 1
""", """, # noqa: S608
values, values,
Payment, Payment,
) )
@@ -80,14 +82,20 @@ async def get_latest_payments_by_extension(
ext_name: str, ext_id: str, limit: int = 5 ext_name: str, ext_id: str, limit: int = 5
) -> list[Payment]: ) -> list[Payment]:
return await db.fetchall( return await db.fetchall(
# This query is safe from SQL injection
# The limtit is an integer and not user input
f""" f"""
SELECT * FROM apipayments SELECT * FROM apipayments
WHERE status = '{PaymentState.SUCCESS}' WHERE status = :status
AND extra LIKE :ext_name AND extra LIKE :ext_name
AND extra LIKE :ext_id AND extra LIKE :ext_id
ORDER BY time DESC LIMIT {int(limit)} ORDER BY time DESC LIMIT {int(limit)}
""", """, # noqa: S608
{"ext_name": f"%{ext_name}%", "ext_id": f"%{ext_id}%"}, {
"status": f"{PaymentState.SUCCESS}",
"ext_name": f"%{ext_name}%",
"ext_id": f"%{ext_id}%",
},
Payment, Payment,
) )
@@ -227,21 +235,23 @@ async def delete_expired_invoices(
# first we delete all invoices older than one month # first we delete all invoices older than one month
await (conn or db).execute( await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
DELETE FROM apipayments DELETE FROM apipayments
WHERE status = '{PaymentState.PENDING}' AND amount > 0 WHERE status = :status AND amount > 0
AND time < {db.timestamp_placeholder("delta")} AND time < {db.timestamp_placeholder("delta")}
""", """, # noqa: S608
{"delta": int(time() - 2592000)}, {"status": f"{PaymentState.PENDING}", "delta": int(time() - 2592000)},
) )
# then we delete all invoices whose expiry date is in the past # then we delete all invoices whose expiry date is in the past
await (conn or db).execute( await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
DELETE FROM apipayments DELETE FROM apipayments
WHERE status = '{PaymentState.PENDING}' AND amount > 0 WHERE status = :status AND amount > 0
AND expiry < {db.timestamp_placeholder("now")} AND expiry < {db.timestamp_placeholder("now")}
""", """, # noqa: S608
{"now": int(time())}, {"status": f"{PaymentState.PENDING}", "now": int(time())},
) )
@@ -321,16 +331,20 @@ async def get_payments_history(
) )
""" """
] ]
clause = filters.where(where)
transactions: list[dict] = await db.fetchall( transactions: list[dict] = await db.fetchall(
# This query is safe from SQL injection:
# - `date_trunc` is a static string, not user input
# - `clause` is generated from filters, which are validated and sanitized
f""" f"""
SELECT {date_trunc} date, SELECT {date_trunc} date,
SUM(CASE WHEN amount > 0 THEN amount ELSE 0 END) income, SUM(CASE WHEN amount > 0 THEN amount ELSE 0 END) income,
SUM(CASE WHEN amount < 0 THEN abs(amount) + abs(fee) ELSE 0 END) spending SUM(CASE WHEN amount < 0 THEN abs(amount) + abs(fee) ELSE 0 END) spending
FROM apipayments FROM apipayments
{filters.where(where)} {clause}
GROUP BY date GROUP BY date
ORDER BY date DESC ORDER BY date DESC
""", """, # noqa: S608
filters.values(values), filters.values(values),
) )
if wallet_id: if wallet_id:
@@ -376,13 +390,16 @@ async def get_payment_count_stats(
clause = filters.where(extra_stmts) clause = filters.where(extra_stmts)
data = await (conn or db).fetchall( data = await (conn or db).fetchall(
# SQL injection vectors safety:
# - `field` is a static string, not user input
# - `clause` is generated from filters, which are validated and sanitized
query=f""" query=f"""
SELECT {field} as field, count(*) as total SELECT {field} as field, count(*) as total
FROM apipayments FROM apipayments
{clause} {clause}
GROUP BY {field} GROUP BY {field}
ORDER BY {field} ORDER BY {field}
""", """, # noqa: S608
values=filters.values(), values=filters.values(),
model=PaymentCountStat, model=PaymentCountStat,
) )
@@ -468,6 +485,8 @@ async def get_wallets_stats(
clauses = filters.where(where_stmts) clauses = filters.where(where_stmts)
data = await (conn or db).fetchall( data = await (conn or db).fetchall(
# This query is safe from SQL injection
# The `clauses` is constructed from sanitized inputs
query=f""" query=f"""
SELECT apipayments.wallet_id, SELECT apipayments.wallet_id,
MAX(wallets.name) AS wallet_name, MAX(wallets.name) AS wallet_name,
@@ -479,7 +498,7 @@ async def get_wallets_stats(
{clauses} {clauses}
GROUP BY apipayments.wallet_id GROUP BY apipayments.wallet_id
ORDER BY payments_count ORDER BY payments_count
""", """, # noqa: S608
values=filters.values(), values=filters.values(),
model=PaymentWalletStats, model=PaymentWalletStats,
) )
@@ -504,11 +523,11 @@ async def check_internal(
otherwise None otherwise None
""" """
return await (conn or db).fetchone( return await (conn or db).fetchone(
f""" """
SELECT * FROM apipayments SELECT * FROM apipayments
WHERE payment_hash = :hash AND status = '{PaymentState.PENDING}' AND amount > 0 WHERE payment_hash = :hash AND status = :status AND amount > 0
""", """,
{"hash": payment_hash}, {"status": f"{PaymentState.PENDING}", "hash": payment_hash},
Payment, Payment,
) )

View File

@@ -110,6 +110,7 @@ async def delete_accounts_no_wallets(
) -> None: ) -> None:
delta = int(time()) - time_delta delta = int(time()) - time_delta
await (conn or db).execute( await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
DELETE FROM accounts DELETE FROM accounts
WHERE NOT EXISTS ( WHERE NOT EXISTS (
@@ -118,7 +119,7 @@ async def delete_accounts_no_wallets(
(updated_at is null AND created_at < :delta) (updated_at is null AND created_at < :delta)
OR updated_at < {db.timestamp_placeholder("delta")} OR updated_at < {db.timestamp_placeholder("delta")}
) )
""", """, # noqa: S608
{"delta": delta}, {"delta": delta},
) )

View File

@@ -48,11 +48,12 @@ async def delete_wallet(
) -> None: ) -> None:
now = int(time()) now = int(time())
await (conn or db).execute( await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
UPDATE wallets UPDATE wallets
SET deleted = :deleted, updated_at = {db.timestamp_placeholder('now')} SET deleted = :deleted, updated_at = {db.timestamp_placeholder('now')}
WHERE id = :wallet AND "user" = :user WHERE id = :wallet AND "user" = :user
""", """, # noqa: S608
{"wallet": wallet_id, "user": user_id, "deleted": deleted, "now": now}, {"wallet": wallet_id, "user": user_id, "deleted": deleted, "now": now},
) )
@@ -71,11 +72,12 @@ async def delete_wallet_by_id(
) -> Optional[int]: ) -> Optional[int]:
now = int(time()) now = int(time())
result = await (conn or db).execute( result = await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
UPDATE wallets UPDATE wallets
SET deleted = true, updated_at = {db.timestamp_placeholder('now')} SET deleted = true, updated_at = {db.timestamp_placeholder('now')}
WHERE id = :wallet WHERE id = :wallet
""", """, # noqa: S608
{"wallet": wallet_id, "now": now}, {"wallet": wallet_id, "now": now},
) )
return result.rowcount return result.rowcount
@@ -107,14 +109,16 @@ async def delete_unused_wallets(
async def get_wallet( async def get_wallet(
wallet_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None wallet_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> Optional[Wallet]: ) -> Optional[Wallet]:
where = "AND deleted = :deleted" if deleted is not None else "" query = """
return await (conn or db).fetchone(
f"""
SELECT *, COALESCE(( SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet_id = wallets.id SELECT balance FROM balances WHERE wallet_id = wallets.id
), 0) AS balance_msat FROM wallets ), 0) AS balance_msat FROM wallets
WHERE id = :wallet {where} WHERE id = :wallet
""", """
if deleted is not None:
query += " AND deleted = :deleted "
return await (conn or db).fetchone(
query,
{"wallet": wallet_id, "deleted": deleted}, {"wallet": wallet_id, "deleted": deleted},
Wallet, Wallet,
) )
@@ -123,14 +127,16 @@ async def get_wallet(
async def get_wallets( async def get_wallets(
user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> list[Wallet]: ) -> list[Wallet]:
where = "AND deleted = :deleted" if deleted is not None else "" query = """
return await (conn or db).fetchall(
f"""
SELECT *, COALESCE(( SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet_id = wallets.id SELECT balance FROM balances WHERE wallet_id = wallets.id
), 0) AS balance_msat FROM wallets ), 0) AS balance_msat FROM wallets
WHERE "user" = :user {where} WHERE "user" = :user
""", """
if deleted is not None:
query += " AND deleted = :deleted "
return await (conn or db).fetchall(
query,
{"user": user_id, "deleted": deleted}, {"user": user_id, "deleted": deleted},
Wallet, Wallet,
) )
@@ -162,12 +168,11 @@ async def get_wallets_paginated(
async def get_wallets_ids( async def get_wallets_ids(
user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> list[str]: ) -> list[str]:
where = "AND deleted = :deleted" if deleted is not None else "" query = """SELECT id FROM wallets WHERE "user" = :user"""
if deleted is not None:
query += "AND deleted = :deleted"
result: list[dict] = await (conn or db).fetchall( result: list[dict] = await (conn or db).fetchall(
f""" query,
SELECT id FROM wallets
WHERE "user" = :user {where}
""",
{"user": user_id, "deleted": deleted}, {"user": user_id, "deleted": deleted},
) )
return [row["id"] for row in result] return [row["id"] for row in result]

View File

@@ -214,6 +214,7 @@ async def m007_set_invoice_expiries(db: Connection):
""" """
try: try:
result = await db.execute( result = await db.execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
SELECT bolt11, checking_id SELECT bolt11, checking_id
FROM apipayments FROM apipayments
@@ -222,7 +223,7 @@ async def m007_set_invoice_expiries(db: Connection):
AND bolt11 IS NOT NULL AND bolt11 IS NOT NULL
AND expiry IS NULL AND expiry IS NULL
AND time < {db.timestamp_now} AND time < {db.timestamp_now}
""" """ # noqa: S608
) )
rows = result.mappings().all() rows = result.mappings().all()
if len(rows): if len(rows):
@@ -242,10 +243,11 @@ async def m007_set_invoice_expiries(db: Connection):
f" {invoice.payment_hash} to {expiration_date}" f" {invoice.payment_hash} to {expiration_date}"
) )
await db.execute( await db.execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
UPDATE apipayments SET expiry = {db.timestamp_placeholder('expiry')} UPDATE apipayments SET expiry = {db.timestamp_placeholder('expiry')}
WHERE checking_id = :checking_id AND amount > 0 WHERE checking_id = :checking_id AND amount > 0
""", """, # noqa: S608
{"expiry": expiration_date, "checking_id": checking_id}, {"expiry": expiration_date, "checking_id": checking_id},
) )
except Exception as exc: except Exception as exc:
@@ -456,17 +458,19 @@ async def m017_add_timestamp_columns_to_accounts_and_wallets(db: Connection):
# set all to now where they are null # set all to now where they are null
now = int(time()) now = int(time())
await db.execute( await db.execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
UPDATE wallets SET created_at = {db.timestamp_placeholder('now')} UPDATE wallets SET created_at = {db.timestamp_placeholder('now')}
WHERE created_at IS NULL WHERE created_at IS NULL
""", """, # noqa: S608
{"now": now}, {"now": now},
) )
await db.execute( await db.execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
UPDATE accounts SET created_at = {db.timestamp_placeholder('now')} UPDATE accounts SET created_at = {db.timestamp_placeholder('now')}
WHERE created_at IS NULL WHERE created_at IS NULL
""", """, # noqa: S608
{"now": now}, {"now": now},
) )
@@ -618,7 +622,12 @@ async def m027_update_apipayments_data(db: Connection):
logger.info(f"Updating {offset} to {offset+limit}") logger.info(f"Updating {offset} to {offset+limit}")
result = await db.execute( result = await db.execute(
f"SELECT * FROM apipayments ORDER BY time LIMIT {limit} OFFSET {offset}" # Limit and Offset safe from SQL injection
# since they are integers and are not user input
f"""
SELECT * FROM apipayments
ORDER BY time LIMIT {int(limit)} OFFSET {int(offset)}
""" # noqa: S608
) )
payments = result.mappings().all() payments = result.mappings().all()
logger.info(f"Payments count: {len(payments)}") logger.info(f"Payments count: {len(payments)}")
@@ -631,11 +640,12 @@ async def m027_update_apipayments_data(db: Connection):
tag = extra.get("tag") tag = extra.get("tag")
tsph = db.timestamp_placeholder("created_at") tsph = db.timestamp_placeholder("created_at")
await db.execute( await db.execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
UPDATE apipayments UPDATE apipayments
SET tag = :tag, created_at = {tsph}, updated_at = {tsph} SET tag = :tag, created_at = {tsph}, updated_at = {tsph}
WHERE checking_id = :checking_id WHERE checking_id = :checking_id
""", """, # noqa: S608
{ {
"tag": tag, "tag": tag,
"created_at": created_at, "created_at": created_at,

View File

@@ -252,7 +252,7 @@ class Connection(Compat):
{clause} {clause}
{group_by_string} {group_by_string}
) as count ) as count
""", """, # noqa: S608
parsed_values, parsed_values,
) )
row = result.mappings().first() row = result.mappings().first()
@@ -597,7 +597,7 @@ def insert_query(table_name: str, model: BaseModel) -> str:
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword) # add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
fields = ", ".join([f'"{key}"' for key in keys]) fields = ", ".join([f'"{key}"' for key in keys])
values = ", ".join(placeholders) values = ", ".join(placeholders)
return f"INSERT INTO {table_name} ({fields}) VALUES ({values})" return f"INSERT INTO {table_name} ({fields}) VALUES ({values})" # noqa: S608
def update_query( def update_query(
@@ -615,7 +615,7 @@ def update_query(
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword) # add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
fields.append(f'"{field}" = {placeholder}') fields.append(f'"{field}" = {placeholder}')
query = ", ".join(fields) query = ", ".join(fields)
return f"UPDATE {table_name} SET {query} {where}" return f"UPDATE {table_name} SET {query} {where}" # noqa: S608
def model_to_dict(model: BaseModel) -> dict: def model_to_dict(model: BaseModel) -> dict:

View File

@@ -221,10 +221,10 @@ classmethod-decorators = [
# S602 `subprocess` call with `shell=True` identified, security issue # S602 `subprocess` call with `shell=True` identified, security issue
# S603 `subprocess` call: check for execution of untrusted input # S603 `subprocess` call: check for execution of untrusted input
# S607: Starting a process with a partial executable path # S607: Starting a process with a partial executable path
# TODO: do not skip S608:
# S608: Possible SQL injection vector through string-based query construction # S608: Possible SQL injection vector through string-based query construction
# S324 Probable use of insecure hash functions in `hashlib`: `md5` # S324 Probable use of insecure hash functions in `hashlib`: `md5`
"lnbits/*" = ["S101", "S608"] # TODO: remove S101 ignore
"lnbits/*" = ["S101"]
"lnbits/core/views/admin_api.py" = ["S602", "S603", "S607"] "lnbits/core/views/admin_api.py" = ["S602", "S603", "S607"]
"crypto.py" = ["S324"] "crypto.py" = ["S324"]
"test*.py" = ["S101", "S105", "S106", "S307"] "test*.py" = ["S101", "S105", "S106", "S307"]
@@ -232,9 +232,6 @@ classmethod-decorators = [
"tests/*" = ["S311"] "tests/*" = ["S311"]
"tests/regtest/helpers.py" = ["S603"] "tests/regtest/helpers.py" = ["S603"]
[tool.bandit]
skips = ["B101", "B404"]
[tool.ruff.lint.mccabe] [tool.ruff.lint.mccabe]
max-complexity = 10 max-complexity = 10