mirror of
https://github.com/lnbits/lnbits.git
synced 2025-06-04 12:13:34 +02:00
feat: add group_by to fetch_page (#2140)
--------- Co-authored-by: Pavol Rusnak <pavol@rusnak.io> Co-authored-by: Vlad Stan <stan.v.vlad@gmail.com>
This commit is contained in:
parent
14519135d8
commit
7ce4eddb0e
15
lnbits/db.py
15
lnbits/db.py
@ -179,16 +179,27 @@ class Connection(Compat):
|
|||||||
values: Optional[List[str]] = None,
|
values: Optional[List[str]] = None,
|
||||||
filters: Optional[Filters] = None,
|
filters: Optional[Filters] = None,
|
||||||
model: Optional[Type[TRowModel]] = None,
|
model: Optional[Type[TRowModel]] = None,
|
||||||
|
group_by: Optional[List[str]] = None,
|
||||||
) -> Page[TRowModel]:
|
) -> Page[TRowModel]:
|
||||||
if not filters:
|
if not filters:
|
||||||
filters = Filters()
|
filters = Filters()
|
||||||
clause = filters.where(where)
|
clause = filters.where(where)
|
||||||
parsed_values = filters.values(values)
|
parsed_values = filters.values(values)
|
||||||
|
|
||||||
|
group_by_string = ""
|
||||||
|
if group_by:
|
||||||
|
for field in group_by:
|
||||||
|
if not re.fullmatch(
|
||||||
|
r"[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?", field
|
||||||
|
):
|
||||||
|
raise ValueError("Value for GROUP BY is invalid")
|
||||||
|
group_by_string = f"GROUP BY {', '.join(group_by)}"
|
||||||
|
|
||||||
rows = await self.fetchall(
|
rows = await self.fetchall(
|
||||||
f"""
|
f"""
|
||||||
{query}
|
{query}
|
||||||
{clause}
|
{clause}
|
||||||
|
{group_by_string}
|
||||||
{filters.order_by()}
|
{filters.order_by()}
|
||||||
{filters.pagination()}
|
{filters.pagination()}
|
||||||
""",
|
""",
|
||||||
@ -202,6 +213,7 @@ class Connection(Compat):
|
|||||||
SELECT COUNT(*) FROM (
|
SELECT COUNT(*) FROM (
|
||||||
{query}
|
{query}
|
||||||
{clause}
|
{clause}
|
||||||
|
{group_by_string}
|
||||||
) as count
|
) as count
|
||||||
""",
|
""",
|
||||||
parsed_values,
|
parsed_values,
|
||||||
@ -288,9 +300,10 @@ class Database(Compat):
|
|||||||
values: Optional[List[str]] = None,
|
values: Optional[List[str]] = None,
|
||||||
filters: Optional[Filters] = None,
|
filters: Optional[Filters] = None,
|
||||||
model: Optional[Type[TRowModel]] = None,
|
model: Optional[Type[TRowModel]] = None,
|
||||||
|
group_by: Optional[List[str]] = None,
|
||||||
) -> Page[TRowModel]:
|
) -> Page[TRowModel]:
|
||||||
async with self.connect() as conn:
|
async with self.connect() as conn:
|
||||||
return await conn.fetch_page(query, where, values, filters, model)
|
return await conn.fetch_page(query, where, values, filters, model, group_by)
|
||||||
|
|
||||||
async def execute(self, query: str, values: tuple = ()):
|
async def execute(self, query: str, values: tuple = ()):
|
||||||
async with self.connect() as conn:
|
async with self.connect() as conn:
|
||||||
|
74
tests/core/test_db_fetch_page.py
Normal file
74
tests/core/test_db_fetch_page.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from tests.helpers import DbTestModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def fetch_page(db):
|
||||||
|
await db.execute("DROP TABLE IF EXISTS test_db_fetch_page")
|
||||||
|
await db.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE test_db_fetch_page (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
value TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO test_db_fetch_page (id, name, value) VALUES
|
||||||
|
('1', 'Alice', 'foo'),
|
||||||
|
('2', 'Bob', 'bar'),
|
||||||
|
('3', 'Carol', 'bar'),
|
||||||
|
('4', 'Dave', 'bar'),
|
||||||
|
('5', 'Dave', 'foo')
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
yield
|
||||||
|
await db.execute("DROP TABLE test_db_fetch_page")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_db_fetch_page_simple(fetch_page, db):
|
||||||
|
row = await db.fetch_page(
|
||||||
|
query="select * from test_db_fetch_page",
|
||||||
|
model=DbTestModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert row
|
||||||
|
assert row.total == 5
|
||||||
|
assert len(row.data) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_db_fetch_page_group_by(fetch_page, db):
|
||||||
|
row = await db.fetch_page(
|
||||||
|
query="select max(id) as id, name from test_db_fetch_page",
|
||||||
|
model=DbTestModel,
|
||||||
|
group_by=["name"],
|
||||||
|
)
|
||||||
|
assert row
|
||||||
|
assert row.total == 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_db_fetch_page_group_by_multiple(fetch_page, db):
|
||||||
|
row = await db.fetch_page(
|
||||||
|
query="select max(id) as id, name, value from test_db_fetch_page",
|
||||||
|
model=DbTestModel,
|
||||||
|
group_by=["value", "name"],
|
||||||
|
)
|
||||||
|
assert row
|
||||||
|
assert row.total == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_db_fetch_page_group_by_evil(fetch_page, db):
|
||||||
|
with pytest.raises(ValueError, match="Value for GROUP BY is invalid"):
|
||||||
|
await db.fetch_page(
|
||||||
|
query="select * from test_db_fetch_page",
|
||||||
|
model=DbTestModel,
|
||||||
|
group_by=["name;"],
|
||||||
|
)
|
@ -1,27 +1,21 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from lnbits.helpers import (
|
from lnbits.helpers import (
|
||||||
insert_query,
|
insert_query,
|
||||||
update_query,
|
update_query,
|
||||||
)
|
)
|
||||||
|
from tests.helpers import DbTestModel
|
||||||
|
|
||||||
|
test = DbTestModel(id=1, name="test", value="yes")
|
||||||
class DbTestModel(BaseModel):
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
test = DbTestModel(id=1, name="test")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_helpers_insert_query():
|
async def test_helpers_insert_query():
|
||||||
q = insert_query("test_helpers_query", test)
|
q = insert_query("test_helpers_query", test)
|
||||||
assert q == "INSERT INTO test_helpers_query (id, name) VALUES (?, ?)"
|
assert q == "INSERT INTO test_helpers_query (id, name, value) VALUES (?, ?, ?)"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_helpers_update_query():
|
async def test_helpers_update_query():
|
||||||
q = update_query("test_helpers_query", test)
|
q = update_query("test_helpers_query", test)
|
||||||
assert q == "UPDATE test_helpers_query SET id = ?, name = ? WHERE id = ?"
|
assert q == "UPDATE test_helpers_query SET id = ?, name = ?, value = ? WHERE id = ?"
|
||||||
|
@ -5,17 +5,23 @@ import random
|
|||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
from subprocess import PIPE, Popen, TimeoutExpired
|
from subprocess import PIPE, Popen, TimeoutExpired
|
||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from psycopg2 import connect
|
from psycopg2 import connect
|
||||||
from psycopg2.errors import InvalidCatalogName
|
from psycopg2.errors import InvalidCatalogName
|
||||||
|
|
||||||
from lnbits import core
|
from lnbits import core
|
||||||
from lnbits.db import DB_TYPE, POSTGRES
|
from lnbits.db import DB_TYPE, POSTGRES, FromRowModel
|
||||||
from lnbits.wallets import get_wallet_class, set_wallet_class
|
from lnbits.wallets import get_wallet_class, set_wallet_class
|
||||||
|
|
||||||
|
|
||||||
|
class DbTestModel(FromRowModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
value: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def get_random_string(N: int = 10):
|
def get_random_string(N: int = 10):
|
||||||
return "".join(
|
return "".join(
|
||||||
random.SystemRandom().choice(string.ascii_uppercase + string.digits)
|
random.SystemRandom().choice(string.ascii_uppercase + string.digits)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user