This commit is contained in:
dni ⚡ 2024-09-26 13:10:01 +02:00 committed by Vlad Stan
parent 185ed0202e
commit 7712d10d83
4 changed files with 52 additions and 48 deletions

View File

@ -13,7 +13,7 @@ from ecdsa import SECP256k1, SigningKey
from fastapi import Query
from pydantic import BaseModel, validator
from lnbits.db import FilterModel, FromRowModel
from lnbits.db import FilterModel
from lnbits.helpers import url_for
from lnbits.lnurl import encode as lnurl_encode
from lnbits.settings import settings
@ -103,7 +103,7 @@ class UserConfig(BaseModel):
provider: Optional[str] = "lnbits" # auth provider
class Account(FromRowModel):
class Account(BaseModel):
id: str
is_super_user: Optional[bool] = False
is_admin: Optional[bool] = False
@ -244,7 +244,7 @@ class CreatePayment(BaseModel):
fee: int = 0
class Payment(FromRowModel):
class Payment(BaseModel):
status: str
# TODO should be removed in the future, backward compatibility
pending: bool

View File

@ -146,7 +146,10 @@ class Connection(Compat):
return clean_values
async def fetchall(
self, query: str, values: Optional[dict] = None, model: Optional[TModel] = None
self,
query: str,
values: Optional[dict] = None,
model: Optional[type[TModel]] = None,
) -> list[TModel]:
params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params)
@ -159,7 +162,10 @@ class Connection(Compat):
return row
async def fetchone(
self, query: str, values: Optional[dict] = None, model: Optional[TModel] = None
self,
query: str,
values: Optional[dict] = None,
model: Optional[type[TModel]] = None,
) -> TModel:
params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params)
@ -187,9 +193,9 @@ class Connection(Compat):
where: Optional[list[str]] = None,
values: Optional[dict] = None,
filters: Optional[Filters] = None,
model: Optional[type[TRowModel]] = None,
model: Optional[type[TModel]] = None,
group_by: Optional[list[str]] = None,
) -> Page[TRowModel]:
) -> Page[TModel]:
if not filters:
filters = Filters()
clause = filters.where(where)
@ -213,11 +219,12 @@ class Connection(Compat):
{filters.pagination()}
""",
self.rewrite_values(parsed_values),
model,
)
if rows:
# no need for extra query if no pagination is specified
if filters.offset or filters.limit:
result = await self.fetchone(
result = await self.execute(
f"""
SELECT COUNT(*) as count FROM (
{query}
@ -234,7 +241,7 @@ class Connection(Compat):
count = 0
return Page(
data=[model.from_row(row) for row in rows] if model else [],
data=rows,
total=count,
)
@ -343,9 +350,9 @@ class Database(Compat):
where: Optional[list[str]] = None,
values: Optional[dict] = None,
filters: Optional[Filters] = None,
model: Optional[type[TRowModel]] = None,
model: Optional[type[TModel]] = None,
group_by: Optional[list[str]] = None,
) -> Page[TRowModel]:
) -> Page[TModel]:
async with self.connect() as conn:
return await conn.fetch_page(query, where, values, filters, model, group_by)
@ -405,12 +412,6 @@ class Operator(Enum):
raise ValueError("Unknown SQL Operator")
class FromRowModel(BaseModel):
@classmethod
def from_row(cls, row: dict):
return cls(**row)
class FilterModel(BaseModel):
__search_fields__: list[str] = []
__sort_fields__: Optional[list[str]] = None
@ -418,7 +419,6 @@ class FilterModel(BaseModel):
T = TypeVar("T")
TModel = TypeVar("TModel", bound=BaseModel)
TRowModel = TypeVar("TRowModel", bound=FromRowModel)
TFilterModel = TypeVar("TFilterModel", bound=FilterModel)
@ -585,19 +585,27 @@ def update_query(
def model_to_dict(model: BaseModel) -> dict:
"""
Convert a Pydantic model to a dictionary with JSON-encoded nested models
TODO: no recursion, maybe make them recursive?
"""
_dict = model.dict()
for key, value in _dict.items():
if key.startswith("_"):
continue
type_ = model.__fields__[key].type_
if type_ == BaseModel:
_dict[key] = json.dumps(value.dict())
if type(type_) is type(BaseModel):
_dict[key] = json.dumps(value)
return _dict
def dict_to_model(_dict: dict, model: TModel) -> TModel:
def dict_to_model(_dict: dict, model: type[TModel]) -> TModel:
"""
Convert a dictionary with JSON-encoded nested models to a Pydantic model
TODO: no recursion, maybe make them recursive?
"""
for key, value in _dict.items():
type_ = model.__fields__[key].type_
if type_ is BaseModel:
_dict[key] = json.loads(value)
if issubclass(type_, BaseModel):
_dict[key] = type_.construct(**json.loads(value))
return model.construct(**_dict)

View File

@ -4,7 +4,6 @@ from typing import Optional
from pydantic import BaseModel
from lnbits.db import FromRowModel
from lnbits.wallets import get_funding_source, set_funding_source
@ -12,23 +11,17 @@ class FakeError(Exception):
pass
class DbTestModel(FromRowModel):
class DbTestModel(BaseModel):
id: int
name: str
value: Optional[str] = None
class DbTestModelInner(BaseModel):
id: int
label: str
description: Optional[str] = None
class DbTestModel2(BaseModel):
id: int
name: str
value: Optional[str] = None
child: DbTestModelInner
label: str
description: Optional[str] = None
child: DbTestModel
def get_random_string(iterations: int = 10):

View File

@ -1,3 +1,5 @@
import json
import pytest
from lnbits.db import (
@ -6,13 +8,13 @@ from lnbits.db import (
model_to_dict,
update_query,
)
from tests.helpers import DbTestModel2, DbTestModelInner
from tests.helpers import DbTestModel, DbTestModel2
test_data = DbTestModel2(
id=1,
name="test",
value="myvalue",
child=DbTestModelInner(id=2, label="mylabel", description="mydesc"),
label="test",
description="mydesc",
child=DbTestModel(id=2, name="myname", value="myvalue"),
)
@ -20,8 +22,8 @@ test_data = DbTestModel2(
async def test_helpers_insert_query():
q = insert_query("test_helpers_query", test_data)
assert (
q == "INSERT INTO test_helpers_query (id, name, value, child) "
"VALUES (:id, :name, :value, :child)"
q == "INSERT INTO test_helpers_query (id, label, description, child) "
"VALUES (:id, :label, :description, :child)"
)
@ -30,23 +32,24 @@ async def test_helpers_update_query():
q = update_query("test_helpers_query", test_data)
assert (
q == "UPDATE test_helpers_query "
"SET id = :id, name = :name, value = :value, child = :child "
"SET id = :id, label = :label, description = :description, child = :child "
"WHERE id = :id"
)
child_dict = json.dumps({"id": 2, "name": "myname", "value": "myvalue"})
test_dict = {"id": 1, "label": "test", "description": "mydesc", "child": child_dict}
@pytest.mark.asyncio
async def test_helpers_model_to_dict():
d = model_to_dict(test_data)
assert d == {
"id": 1,
"name": "test",
"value": "myvalue",
"child": {"id": 2, "label": "mylabel", "description": "mydesc"},
}
assert d == test_dict
@pytest.mark.asyncio
async def test_helpers_dict_to_model():
m = dict_to_model(model_to_dict(test_data), DbTestModel2)
m = dict_to_model(test_dict, DbTestModel2)
assert m == test_data
assert type(m) is DbTestModel2
assert type(m.child) is DbTestModel