diff --git a/lnbits/db.py b/lnbits/db.py index be120531b..752e2c509 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -11,7 +11,7 @@ from sqlite3 import Row from typing import Any, Generic, List, Literal, Optional, Type, TypeVar from loguru import logger -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ValidationError, root_validator from sqlalchemy import create_engine from sqlalchemy_aio.base import AsyncConnection from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY @@ -344,6 +344,7 @@ class FromRowModel(BaseModel): class FilterModel(BaseModel): __search_fields__: List[str] = [] + __sort_fields__: Optional[List[str]] = None T = TypeVar("T") @@ -427,6 +428,14 @@ class Filter(BaseModel, Generic[TFilterModel]): class Filters(BaseModel, Generic[TFilterModel]): + """ + Generic helper class for filtering and sorting data. + For usage in an api endpoint, use the `parse_filters` dependency. + + When constructing this class manually always make sure to pass a model so that the values can be validated. + Otherwise, make sure to validate the inputs manually. + """ + filters: List[Filter[TFilterModel]] = [] search: Optional[str] = None @@ -438,6 +447,18 @@ class Filters(BaseModel, Generic[TFilterModel]): model: Optional[Type[TFilterModel]] = None + @root_validator(pre=True) + def validate_sortby(cls, values): + sortby = values.get("sortby") + model = values.get("model") + if sortby and model: + model = values["model"] + # if no sort fields are specified explicitly all fields are allowed + allowed = model.__sort_fields__ or model.__fields__ + if sortby not in allowed: + raise ValueError("Invalid sort field") + return values + def pagination(self) -> str: stmt = "" if self.limit: