From 7712d10d832c4f887e5c416edbdcd0320ed87ce8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Thu, 26 Sep 2024 13:10:01 +0200 Subject: [PATCH] fixup! --- lnbits/core/models.py | 6 ++-- lnbits/db.py | 48 +++++++++++++++++++------------- tests/helpers.py | 15 +++------- tests/unit/test_helpers_query.py | 31 +++++++++++---------- 4 files changed, 52 insertions(+), 48 deletions(-) diff --git a/lnbits/core/models.py b/lnbits/core/models.py index d246aac54..916109554 100644 --- a/lnbits/core/models.py +++ b/lnbits/core/models.py @@ -13,7 +13,7 @@ from ecdsa import SECP256k1, SigningKey from fastapi import Query from pydantic import BaseModel, validator -from lnbits.db import FilterModel, FromRowModel +from lnbits.db import FilterModel from lnbits.helpers import url_for from lnbits.lnurl import encode as lnurl_encode from lnbits.settings import settings @@ -103,7 +103,7 @@ class UserConfig(BaseModel): provider: Optional[str] = "lnbits" # auth provider -class Account(FromRowModel): +class Account(BaseModel): id: str is_super_user: Optional[bool] = False is_admin: Optional[bool] = False @@ -244,7 +244,7 @@ class CreatePayment(BaseModel): fee: int = 0 -class Payment(FromRowModel): +class Payment(BaseModel): status: str # TODO should be removed in the future, backward compatibility pending: bool diff --git a/lnbits/db.py b/lnbits/db.py index 95c526722..d448aff52 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -146,7 +146,10 @@ class Connection(Compat): return clean_values async def fetchall( - self, query: str, values: Optional[dict] = None, model: Optional[TModel] = None + self, + query: str, + values: Optional[dict] = None, + model: Optional[type[TModel]] = None, ) -> list[TModel]: params = self.rewrite_values(values) if values else {} result = await self.conn.execute(text(self.rewrite_query(query)), params) @@ -159,7 +162,10 @@ class Connection(Compat): return row async def fetchone( - self, query: str, values: Optional[dict] = None, model: Optional[TModel] = None + self, + query: str, + values: Optional[dict] = None, + model: Optional[type[TModel]] = None, ) -> TModel: params = self.rewrite_values(values) if values else {} result = await self.conn.execute(text(self.rewrite_query(query)), params) @@ -187,9 +193,9 @@ class Connection(Compat): where: Optional[list[str]] = None, values: Optional[dict] = None, filters: Optional[Filters] = None, - model: Optional[type[TRowModel]] = None, + model: Optional[type[TModel]] = None, group_by: Optional[list[str]] = None, - ) -> Page[TRowModel]: + ) -> Page[TModel]: if not filters: filters = Filters() clause = filters.where(where) @@ -213,11 +219,12 @@ class Connection(Compat): {filters.pagination()} """, self.rewrite_values(parsed_values), + model, ) if rows: # no need for extra query if no pagination is specified if filters.offset or filters.limit: - result = await self.fetchone( + result = await self.execute( f""" SELECT COUNT(*) as count FROM ( {query} @@ -234,7 +241,7 @@ class Connection(Compat): count = 0 return Page( - data=[model.from_row(row) for row in rows] if model else [], + data=rows, total=count, ) @@ -343,9 +350,9 @@ class Database(Compat): where: Optional[list[str]] = None, values: Optional[dict] = None, filters: Optional[Filters] = None, - model: Optional[type[TRowModel]] = None, + model: Optional[type[TModel]] = None, group_by: Optional[list[str]] = None, - ) -> Page[TRowModel]: + ) -> Page[TModel]: async with self.connect() as conn: return await conn.fetch_page(query, where, values, filters, model, group_by) @@ -405,12 +412,6 @@ class Operator(Enum): raise ValueError("Unknown SQL Operator") -class FromRowModel(BaseModel): - @classmethod - def from_row(cls, row: dict): - return cls(**row) - - class FilterModel(BaseModel): __search_fields__: list[str] = [] __sort_fields__: Optional[list[str]] = None @@ -418,7 +419,6 @@ class FilterModel(BaseModel): T = TypeVar("T") TModel = TypeVar("TModel", bound=BaseModel) -TRowModel = TypeVar("TRowModel", bound=FromRowModel) TFilterModel = TypeVar("TFilterModel", bound=FilterModel) @@ -585,19 +585,27 @@ def update_query( def model_to_dict(model: BaseModel) -> dict: + """ + Convert a Pydantic model to a dictionary with JSON-encoded nested models + TODO: no recursion, maybe make them recursive? + """ _dict = model.dict() for key, value in _dict.items(): if key.startswith("_"): continue type_ = model.__fields__[key].type_ - if type_ == BaseModel: - _dict[key] = json.dumps(value.dict()) + if type(type_) is type(BaseModel): + _dict[key] = json.dumps(value) return _dict -def dict_to_model(_dict: dict, model: TModel) -> TModel: +def dict_to_model(_dict: dict, model: type[TModel]) -> TModel: + """ + Convert a dictionary with JSON-encoded nested models to a Pydantic model + TODO: no recursion, maybe make them recursive? + """ for key, value in _dict.items(): type_ = model.__fields__[key].type_ - if type_ is BaseModel: - _dict[key] = json.loads(value) + if issubclass(type_, BaseModel): + _dict[key] = type_.construct(**json.loads(value)) return model.construct(**_dict) diff --git a/tests/helpers.py b/tests/helpers.py index 8f477b727..adf49ca00 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,7 +4,6 @@ from typing import Optional from pydantic import BaseModel -from lnbits.db import FromRowModel from lnbits.wallets import get_funding_source, set_funding_source @@ -12,23 +11,17 @@ class FakeError(Exception): pass -class DbTestModel(FromRowModel): +class DbTestModel(BaseModel): id: int name: str value: Optional[str] = None -class DbTestModelInner(BaseModel): - id: int - label: str - description: Optional[str] = None - - class DbTestModel2(BaseModel): id: int - name: str - value: Optional[str] = None - child: DbTestModelInner + label: str + description: Optional[str] = None + child: DbTestModel def get_random_string(iterations: int = 10): diff --git a/tests/unit/test_helpers_query.py b/tests/unit/test_helpers_query.py index c99a203d0..c932bb49d 100644 --- a/tests/unit/test_helpers_query.py +++ b/tests/unit/test_helpers_query.py @@ -1,3 +1,5 @@ +import json + import pytest from lnbits.db import ( @@ -6,13 +8,13 @@ from lnbits.db import ( model_to_dict, update_query, ) -from tests.helpers import DbTestModel2, DbTestModelInner +from tests.helpers import DbTestModel, DbTestModel2 test_data = DbTestModel2( id=1, - name="test", - value="myvalue", - child=DbTestModelInner(id=2, label="mylabel", description="mydesc"), + label="test", + description="mydesc", + child=DbTestModel(id=2, name="myname", value="myvalue"), ) @@ -20,8 +22,8 @@ test_data = DbTestModel2( async def test_helpers_insert_query(): q = insert_query("test_helpers_query", test_data) assert ( - q == "INSERT INTO test_helpers_query (id, name, value, child) " - "VALUES (:id, :name, :value, :child)" + q == "INSERT INTO test_helpers_query (id, label, description, child) " + "VALUES (:id, :label, :description, :child)" ) @@ -30,23 +32,24 @@ async def test_helpers_update_query(): q = update_query("test_helpers_query", test_data) assert ( q == "UPDATE test_helpers_query " - "SET id = :id, name = :name, value = :value, child = :child " + "SET id = :id, label = :label, description = :description, child = :child " "WHERE id = :id" ) +child_dict = json.dumps({"id": 2, "name": "myname", "value": "myvalue"}) +test_dict = {"id": 1, "label": "test", "description": "mydesc", "child": child_dict} + + @pytest.mark.asyncio async def test_helpers_model_to_dict(): d = model_to_dict(test_data) - assert d == { - "id": 1, - "name": "test", - "value": "myvalue", - "child": {"id": 2, "label": "mylabel", "description": "mydesc"}, - } + assert d == test_dict @pytest.mark.asyncio async def test_helpers_dict_to_model(): - m = dict_to_model(model_to_dict(test_data), DbTestModel2) + m = dict_to_model(test_dict, DbTestModel2) assert m == test_data + assert type(m) is DbTestModel2 + assert type(m.child) is DbTestModel