feat: specify sort fields (#1716)

* feat: specify sort fields

* doc string
This commit is contained in:
jackstar12
2023-05-22 13:11:05 +02:00
committed by GitHub
parent 37ac630573
commit 252fd6c313

View File

@@ -11,7 +11,7 @@ from sqlite3 import Row
from typing import Any, Generic, List, Literal, Optional, Type, TypeVar from typing import Any, Generic, List, Literal, Optional, Type, TypeVar
from loguru import logger from loguru import logger
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError, root_validator
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy_aio.base import AsyncConnection from sqlalchemy_aio.base import AsyncConnection
from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY from sqlalchemy_aio.strategy import ASYNCIO_STRATEGY
@@ -344,6 +344,7 @@ class FromRowModel(BaseModel):
class FilterModel(BaseModel): class FilterModel(BaseModel):
__search_fields__: List[str] = [] __search_fields__: List[str] = []
__sort_fields__: Optional[List[str]] = None
T = TypeVar("T") T = TypeVar("T")
@@ -427,6 +428,14 @@ class Filter(BaseModel, Generic[TFilterModel]):
class Filters(BaseModel, Generic[TFilterModel]): class Filters(BaseModel, Generic[TFilterModel]):
"""
Generic helper class for filtering and sorting data.
For usage in an api endpoint, use the `parse_filters` dependency.
When constructing this class manually always make sure to pass a model so that the values can be validated.
Otherwise, make sure to validate the inputs manually.
"""
filters: List[Filter[TFilterModel]] = [] filters: List[Filter[TFilterModel]] = []
search: Optional[str] = None search: Optional[str] = None
@@ -438,6 +447,18 @@ class Filters(BaseModel, Generic[TFilterModel]):
model: Optional[Type[TFilterModel]] = None model: Optional[Type[TFilterModel]] = None
@root_validator(pre=True)
def validate_sortby(cls, values):
sortby = values.get("sortby")
model = values.get("model")
if sortby and model:
model = values["model"]
# if no sort fields are specified explicitly all fields are allowed
allowed = model.__sort_fields__ or model.__fields__
if sortby not in allowed:
raise ValueError("Invalid sort field")
return values
def pagination(self) -> str: def pagination(self) -> str:
stmt = "" stmt = ""
if self.limit: if self.limit: