mirror of
https://github.com/lnbits/lnbits.git
synced 2025-04-05 18:38:14 +02:00
fixup!
This commit is contained in:
parent
185ed0202e
commit
7712d10d83
@ -13,7 +13,7 @@ from ecdsa import SECP256k1, SigningKey
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from lnbits.db import FilterModel, FromRowModel
|
||||
from lnbits.db import FilterModel
|
||||
from lnbits.helpers import url_for
|
||||
from lnbits.lnurl import encode as lnurl_encode
|
||||
from lnbits.settings import settings
|
||||
@ -103,7 +103,7 @@ class UserConfig(BaseModel):
|
||||
provider: Optional[str] = "lnbits" # auth provider
|
||||
|
||||
|
||||
class Account(FromRowModel):
|
||||
class Account(BaseModel):
|
||||
id: str
|
||||
is_super_user: Optional[bool] = False
|
||||
is_admin: Optional[bool] = False
|
||||
@ -244,7 +244,7 @@ class CreatePayment(BaseModel):
|
||||
fee: int = 0
|
||||
|
||||
|
||||
class Payment(FromRowModel):
|
||||
class Payment(BaseModel):
|
||||
status: str
|
||||
# TODO should be removed in the future, backward compatibility
|
||||
pending: bool
|
||||
|
48
lnbits/db.py
48
lnbits/db.py
@ -146,7 +146,10 @@ class Connection(Compat):
|
||||
return clean_values
|
||||
|
||||
async def fetchall(
|
||||
self, query: str, values: Optional[dict] = None, model: Optional[TModel] = None
|
||||
self,
|
||||
query: str,
|
||||
values: Optional[dict] = None,
|
||||
model: Optional[type[TModel]] = None,
|
||||
) -> list[TModel]:
|
||||
params = self.rewrite_values(values) if values else {}
|
||||
result = await self.conn.execute(text(self.rewrite_query(query)), params)
|
||||
@ -159,7 +162,10 @@ class Connection(Compat):
|
||||
return row
|
||||
|
||||
async def fetchone(
|
||||
self, query: str, values: Optional[dict] = None, model: Optional[TModel] = None
|
||||
self,
|
||||
query: str,
|
||||
values: Optional[dict] = None,
|
||||
model: Optional[type[TModel]] = None,
|
||||
) -> TModel:
|
||||
params = self.rewrite_values(values) if values else {}
|
||||
result = await self.conn.execute(text(self.rewrite_query(query)), params)
|
||||
@ -187,9 +193,9 @@ class Connection(Compat):
|
||||
where: Optional[list[str]] = None,
|
||||
values: Optional[dict] = None,
|
||||
filters: Optional[Filters] = None,
|
||||
model: Optional[type[TRowModel]] = None,
|
||||
model: Optional[type[TModel]] = None,
|
||||
group_by: Optional[list[str]] = None,
|
||||
) -> Page[TRowModel]:
|
||||
) -> Page[TModel]:
|
||||
if not filters:
|
||||
filters = Filters()
|
||||
clause = filters.where(where)
|
||||
@ -213,11 +219,12 @@ class Connection(Compat):
|
||||
{filters.pagination()}
|
||||
""",
|
||||
self.rewrite_values(parsed_values),
|
||||
model,
|
||||
)
|
||||
if rows:
|
||||
# no need for extra query if no pagination is specified
|
||||
if filters.offset or filters.limit:
|
||||
result = await self.fetchone(
|
||||
result = await self.execute(
|
||||
f"""
|
||||
SELECT COUNT(*) as count FROM (
|
||||
{query}
|
||||
@ -234,7 +241,7 @@ class Connection(Compat):
|
||||
count = 0
|
||||
|
||||
return Page(
|
||||
data=[model.from_row(row) for row in rows] if model else [],
|
||||
data=rows,
|
||||
total=count,
|
||||
)
|
||||
|
||||
@ -343,9 +350,9 @@ class Database(Compat):
|
||||
where: Optional[list[str]] = None,
|
||||
values: Optional[dict] = None,
|
||||
filters: Optional[Filters] = None,
|
||||
model: Optional[type[TRowModel]] = None,
|
||||
model: Optional[type[TModel]] = None,
|
||||
group_by: Optional[list[str]] = None,
|
||||
) -> Page[TRowModel]:
|
||||
) -> Page[TModel]:
|
||||
async with self.connect() as conn:
|
||||
return await conn.fetch_page(query, where, values, filters, model, group_by)
|
||||
|
||||
@ -405,12 +412,6 @@ class Operator(Enum):
|
||||
raise ValueError("Unknown SQL Operator")
|
||||
|
||||
|
||||
class FromRowModel(BaseModel):
|
||||
@classmethod
|
||||
def from_row(cls, row: dict):
|
||||
return cls(**row)
|
||||
|
||||
|
||||
class FilterModel(BaseModel):
|
||||
__search_fields__: list[str] = []
|
||||
__sort_fields__: Optional[list[str]] = None
|
||||
@ -418,7 +419,6 @@ class FilterModel(BaseModel):
|
||||
|
||||
T = TypeVar("T")
|
||||
TModel = TypeVar("TModel", bound=BaseModel)
|
||||
TRowModel = TypeVar("TRowModel", bound=FromRowModel)
|
||||
TFilterModel = TypeVar("TFilterModel", bound=FilterModel)
|
||||
|
||||
|
||||
@ -585,19 +585,27 @@ def update_query(
|
||||
|
||||
|
||||
def model_to_dict(model: BaseModel) -> dict:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary with JSON-encoded nested models
|
||||
TODO: no recursion, maybe make them recursive?
|
||||
"""
|
||||
_dict = model.dict()
|
||||
for key, value in _dict.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
type_ = model.__fields__[key].type_
|
||||
if type_ == BaseModel:
|
||||
_dict[key] = json.dumps(value.dict())
|
||||
if type(type_) is type(BaseModel):
|
||||
_dict[key] = json.dumps(value)
|
||||
return _dict
|
||||
|
||||
|
||||
def dict_to_model(_dict: dict, model: TModel) -> TModel:
|
||||
def dict_to_model(_dict: dict, model: type[TModel]) -> TModel:
|
||||
"""
|
||||
Convert a dictionary with JSON-encoded nested models to a Pydantic model
|
||||
TODO: no recursion, maybe make them recursive?
|
||||
"""
|
||||
for key, value in _dict.items():
|
||||
type_ = model.__fields__[key].type_
|
||||
if type_ is BaseModel:
|
||||
_dict[key] = json.loads(value)
|
||||
if issubclass(type_, BaseModel):
|
||||
_dict[key] = type_.construct(**json.loads(value))
|
||||
return model.construct(**_dict)
|
||||
|
@ -4,7 +4,6 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lnbits.db import FromRowModel
|
||||
from lnbits.wallets import get_funding_source, set_funding_source
|
||||
|
||||
|
||||
@ -12,23 +11,17 @@ class FakeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DbTestModel(FromRowModel):
|
||||
class DbTestModel(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
value: Optional[str] = None
|
||||
|
||||
|
||||
class DbTestModelInner(BaseModel):
|
||||
id: int
|
||||
label: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class DbTestModel2(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
value: Optional[str] = None
|
||||
child: DbTestModelInner
|
||||
label: str
|
||||
description: Optional[str] = None
|
||||
child: DbTestModel
|
||||
|
||||
|
||||
def get_random_string(iterations: int = 10):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from lnbits.db import (
|
||||
@ -6,13 +8,13 @@ from lnbits.db import (
|
||||
model_to_dict,
|
||||
update_query,
|
||||
)
|
||||
from tests.helpers import DbTestModel2, DbTestModelInner
|
||||
from tests.helpers import DbTestModel, DbTestModel2
|
||||
|
||||
test_data = DbTestModel2(
|
||||
id=1,
|
||||
name="test",
|
||||
value="myvalue",
|
||||
child=DbTestModelInner(id=2, label="mylabel", description="mydesc"),
|
||||
label="test",
|
||||
description="mydesc",
|
||||
child=DbTestModel(id=2, name="myname", value="myvalue"),
|
||||
)
|
||||
|
||||
|
||||
@ -20,8 +22,8 @@ test_data = DbTestModel2(
|
||||
async def test_helpers_insert_query():
|
||||
q = insert_query("test_helpers_query", test_data)
|
||||
assert (
|
||||
q == "INSERT INTO test_helpers_query (id, name, value, child) "
|
||||
"VALUES (:id, :name, :value, :child)"
|
||||
q == "INSERT INTO test_helpers_query (id, label, description, child) "
|
||||
"VALUES (:id, :label, :description, :child)"
|
||||
)
|
||||
|
||||
|
||||
@ -30,23 +32,24 @@ async def test_helpers_update_query():
|
||||
q = update_query("test_helpers_query", test_data)
|
||||
assert (
|
||||
q == "UPDATE test_helpers_query "
|
||||
"SET id = :id, name = :name, value = :value, child = :child "
|
||||
"SET id = :id, label = :label, description = :description, child = :child "
|
||||
"WHERE id = :id"
|
||||
)
|
||||
|
||||
|
||||
child_dict = json.dumps({"id": 2, "name": "myname", "value": "myvalue"})
|
||||
test_dict = {"id": 1, "label": "test", "description": "mydesc", "child": child_dict}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_helpers_model_to_dict():
|
||||
d = model_to_dict(test_data)
|
||||
assert d == {
|
||||
"id": 1,
|
||||
"name": "test",
|
||||
"value": "myvalue",
|
||||
"child": {"id": 2, "label": "mylabel", "description": "mydesc"},
|
||||
}
|
||||
assert d == test_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_helpers_dict_to_model():
|
||||
m = dict_to_model(model_to_dict(test_data), DbTestModel2)
|
||||
m = dict_to_model(test_dict, DbTestModel2)
|
||||
assert m == test_data
|
||||
assert type(m) is DbTestModel2
|
||||
assert type(m.child) is DbTestModel
|
||||
|
Loading…
x
Reference in New Issue
Block a user