remove nested filters (#1843)

* remove nested filters

* remove from openapi schema aswell
This commit is contained in:
jackstar12 2023-07-31 10:21:30 +02:00 committed by GitHub
parent 7f0c7138af
commit 67d3a4f359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 40 deletions

View File

@ -375,7 +375,6 @@ class Page(BaseModel, Generic[T]):
class Filter(BaseModel, Generic[TFilterModel]): class Filter(BaseModel, Generic[TFilterModel]):
field: str field: str
nested: Optional[List[str]]
op: Operator = Operator.EQ op: Operator = Operator.EQ
values: list[Any] values: list[Any]
@ -390,55 +389,36 @@ class Filter(BaseModel, Generic[TFilterModel]):
split = key[:-1].split("[") split = key[:-1].split("[")
if len(split) != 2: if len(split) != 2:
raise ValueError("Invalid key") raise ValueError("Invalid key")
field_names = split[0].split(".") field = split[0]
op = Operator(split[1]) op = Operator(split[1])
else: else:
field_names = key.split(".") field = key
op = Operator("eq") op = Operator("eq")
field = field_names[0]
nested = field_names[1:]
if field in model.__fields__: if field in model.__fields__:
compare_field = model.__fields__[field] compare_field = model.__fields__[field]
values = [] values = []
for raw_value in raw_values: for raw_value in raw_values:
# If there is a nested field, pydantic expects a dict, so the raw value is turned into a dict before
# and the converted value is extracted afterwards
for name in reversed(nested):
raw_value = {name: raw_value}
validated, errors = compare_field.validate(raw_value, {}, loc="none") validated, errors = compare_field.validate(raw_value, {}, loc="none")
if errors: if errors:
raise ValidationError(errors=[errors], model=model) raise ValidationError(errors=[errors], model=model)
for name in nested:
if isinstance(validated, dict):
validated = validated[name]
else:
validated = getattr(validated, name)
values.append(validated) values.append(validated)
else: else:
raise ValueError("Unknown filter field") raise ValueError("Unknown filter field")
return cls(field=field, op=op, nested=nested, values=values, model=model) return cls(field=field, op=op, values=values, model=model)
@property @property
def statement(self): def statement(self):
accessor = self.field
if self.nested:
for name in self.nested:
accessor = f"({accessor} ->> '{name}')"
if self.model and self.model.__fields__[self.field].type_ == datetime.datetime: if self.model and self.model.__fields__[self.field].type_ == datetime.datetime:
placeholder = Compat.timestamp_placeholder placeholder = Compat.timestamp_placeholder
else: else:
placeholder = "?" placeholder = "?"
if self.op in (Operator.INCLUDE, Operator.EXCLUDE): if self.op in (Operator.INCLUDE, Operator.EXCLUDE):
placeholders = ", ".join([placeholder] * len(self.values)) placeholders = ", ".join([placeholder] * len(self.values))
stmt = [f"{accessor} {self.op.as_sql} ({placeholders})"] stmt = [f"{self.field} {self.op.as_sql} ({placeholders})"]
else: else:
stmt = [f"{accessor} {self.op.as_sql} {placeholder}"] * len(self.values) stmt = [f"{self.field} {self.op.as_sql} {placeholder}"] * len(self.values)
return " OR ".join(stmt) return " OR ".join(stmt)

View File

@ -4,11 +4,7 @@ from typing import Any, List, Optional, Type
import jinja2 import jinja2
import shortuuid import shortuuid
from pydantic.schema import ( from pydantic.schema import field_schema
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from lnbits.jinja2_templating import Jinja2Templates from lnbits.jinja2_templating import Jinja2Templates
from lnbits.requestvars import g from lnbits.requestvars import g
@ -102,20 +98,11 @@ def generate_filter_params_openapi(model: Type[FilterModel], keep_optional=False
:param keep_optional: If false, all parameters will be optional, otherwise inferred from model :param keep_optional: If false, all parameters will be optional, otherwise inferred from model
""" """
fields = list(model.__fields__.values()) fields = list(model.__fields__.values())
models = get_flat_models_from_fields(fields, set())
namemap = get_model_name_map(models)
params = [] params = []
for field in fields: for field in fields:
schema, definitions, _ = field_schema(field, model_name_map=namemap) schema, _, _ = field_schema(field, model_name_map={})
# Support nested definition
if "$ref" in schema:
name = schema["$ref"].split("/")[-1]
schema = definitions[name]
description = "Supports Filtering" description = "Supports Filtering"
if schema["type"] == "object":
description += f". Nested attributes can be filtered too, e.g. `{field.alias}.[additional].[attributes]`"
if ( if (
hasattr(model, "__search_fields__") hasattr(model, "__search_fields__")
and field.name in model.__search_fields__ and field.name in model.__search_fields__