diff --git a/aredis_om/_compat.py b/aredis_om/_compat.py deleted file mode 100644 index 07dc2824..00000000 --- a/aredis_om/_compat.py +++ /dev/null @@ -1,99 +0,0 @@ -from dataclasses import dataclass, is_dataclass -from typing import ( - Any, - Callable, - Deque, - Dict, - FrozenSet, - List, - Mapping, - Sequence, - Set, - Tuple, - Type, - Union, -) - -from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Annotated, Literal, get_args, get_origin - - -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - -if PYDANTIC_V2: - - def use_pydantic_2_plus(): - return True - - from pydantic import BaseModel, TypeAdapter - from pydantic import ValidationError as ValidationError - from pydantic import validator - from pydantic._internal._model_construction import ModelMetaclass - from pydantic._internal._repr import Representation - from pydantic.deprecated.json import ENCODERS_BY_TYPE - from pydantic.fields import FieldInfo - from pydantic.v1.main import validate_model - from pydantic.v1.typing import NoArgAnyCallable - from pydantic_core import PydanticUndefined as Undefined - from pydantic_core import PydanticUndefinedType as UndefinedType - - @dataclass - class ModelField: - field_info: FieldInfo - name: str - mode: Literal["validation", "serialization"] = "validation" - - @property - def alias(self) -> str: - a = self.field_info.alias - return a if a is not None else self.name - - @property - def required(self) -> bool: - return self.field_info.is_required() - - @property - def default(self) -> Any: - return self.get_default() - - @property - def type_(self) -> Any: - return self.field_info.annotation - - def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info] - ) - - def get_default(self) -> Any: - if self.field_info.is_required(): - return Undefined - return self.field_info.get_default(call_default_factory=True) - - def validate( - self, - value: Any, - values: Dict[str, Any] = {}, # noqa: B006 - *, - loc: Tuple[Union[int, str], ...] = (), - ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: - return ( - self._type_adapter.validate_python(value, from_attributes=True), - None, - ) - - def __hash__(self) -> int: - # Each ModelField is unique for our purposes, to allow making a dict from - # ModelField to its JSON Schema. - return id(self) - -else: - from pydantic import BaseModel, validator - from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType - from pydantic.json import ENCODERS_BY_TYPE - from pydantic.main import ModelMetaclass, validate_model - from pydantic.typing import NoArgAnyCallable - from pydantic.utils import Representation - - def use_pydantic_2_plus(): - return False diff --git a/aredis_om/checks.py b/aredis_om/checks.py index be2332cf..522f5c38 100644 --- a/aredis_om/checks.py +++ b/aredis_om/checks.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import List from aredis_om.connections import get_redis_connection diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index cb71447c..236133e7 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -31,7 +31,9 @@ from types import GeneratorType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union -from .._compat import ENCODERS_BY_TYPE, BaseModel +from pydantic import BaseModel +from pydantic.deprecated.json import ENCODERS_BY_TYPE +from pydantic_core import PydanticUndefined SetIntStr = Set[Union[int, str]] @@ -72,7 +74,7 @@ def jsonable_encoder( encoder = getattr(obj.__config__, "json_encoders", {}) if custom_encoder: encoder.update(custom_encoder) - obj_dict = obj.dict( + obj_dict = obj.model_dump( include=include, # type: ignore # in Pydantic exclude=exclude, # type: ignore # in Pydantic by_alias=by_alias, @@ -106,6 +108,7 @@ def jsonable_encoder( or (not isinstance(key, str)) or (not key.startswith("_sa")) ) + and value is not PydanticUndefined and (value is not None or not exclude_none) and ((include and key in include) or not exclude or key not in exclude) ): diff --git a/aredis_om/model/migrations/migrator.py b/aredis_om/model/migrations/migrator.py index eb3d1b0b..34aa7c14 100644 --- a/aredis_om/model/migrations/migrator.py +++ b/aredis_om/model/migrations/migrator.py @@ -131,7 +131,7 @@ async def detect_migrations(self): stored_hash = await conn.get(hash_key) if isinstance(stored_hash, bytes): - stored_hash = stored_hash.decode('utf-8') + stored_hash = stored_hash.decode("utf-8") schema_out_of_date = current_hash != stored_hash diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 67769cd3..45333eb4 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -8,10 +8,8 @@ from enum import Enum from functools import reduce from typing import ( - AbstractSet, Any, Callable, - ClassVar, Dict, List, Literal, @@ -28,24 +26,19 @@ from typing import no_type_check from more_itertools import ichunked +from pydantic import BaseModel, ConfigDict, TypeAdapter, field_validator +from pydantic._internal._model_construction import ModelMetaclass +from pydantic._internal._repr import Representation +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import _FieldInfoInputs +from pydantic_core import PydanticUndefined as Undefined +from pydantic_core import PydanticUndefinedType as UndefinedType from redis.commands.json.path import Path from redis.exceptions import ResponseError -from typing_extensions import Protocol, get_args, get_origin +from typing_extensions import Protocol, Unpack, get_args, get_origin from ulid import ULID from .. import redis -from .._compat import PYDANTIC_V2, BaseModel -from .._compat import FieldInfo as PydanticFieldInfo -from .._compat import ( - ModelField, - ModelMetaclass, - NoArgAnyCallable, - Representation, - Undefined, - UndefinedType, - validate_model, - validator, -) from ..checks import has_redis_json, has_redisearch from ..connections import get_redis_connection from ..util import ASYNC_MODE @@ -78,7 +71,7 @@ ERRORS_URL = "https://github.com/redis/redis-om-python/blob/main/docs/errors.md" -def get_outer_type(field): +def get_outer_type(field: PydanticFieldInfo): if hasattr(field, "outer_type_"): return field.outer_type_ elif isinstance(field.annotation, type) or is_supported_container_type( @@ -88,7 +81,7 @@ def get_outer_type(field): elif not hasattr(field.annotation, "__args__"): return None else: - return field.annotation.__args__[0] + return field.annotation.__args__[0] # type: ignore class RedisModelError(Exception): @@ -127,9 +120,7 @@ def __str__(self): return str(self.name) -ExpressionOrModelField = Union[ - "Expression", "NegatedExpression", ModelField, PydanticFieldInfo -] +ExpressionOrModelField = Union["Expression", "NegatedExpression", PydanticFieldInfo] def embedded(cls): @@ -164,7 +155,7 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any obj = getattr(obj, sub_field) return - if field_name not in model.__fields__: # type: ignore + if field_name not in model.model_fields: # type: ignore raise QuerySyntaxError( f"The field {field_name} does not exist on the model {model.__name__}" ) @@ -288,26 +279,31 @@ def tree(self): @dataclasses.dataclass class KNNExpression: k: int - vector_field: ModelField + vector_field: "ExpressionProxy" + score_field: "ExpressionProxy" reference_vector: bytes def __str__(self): - return f"KNN $K @{self.vector_field.name} $knn_ref_vector" + return f"KNN $K @{self.vector_field_name} $knn_ref_vector AS {self.score_field_name}" @property def query_params(self) -> Dict[str, Union[str, bytes]]: return {"K": str(self.k), "knn_ref_vector": self.reference_vector} @property - def score_field(self) -> str: - return f"__{self.vector_field.name}_score" + def score_field_name(self) -> str: + return self.score_field.field.name + + @property + def vector_field_name(self) -> str: + return self.vector_field.field.name ExpressionOrNegated = Union[Expression, NegatedExpression] class ExpressionProxy: - def __init__(self, field: ModelField, parents: List[Tuple[str, "RedisModel"]]): + def __init__(self, field: "FieldInfo", parents: List[Tuple[str, "RedisModel"]]): self.field = field self.parents = parents.copy() # Ensure a copy is stored @@ -391,7 +387,7 @@ def __getattr__(self, item): if isinstance(attr, self.__class__): # Clone the parents to ensure isolation new_parents = self.parents.copy() - new_parent = (self.field.alias, outer_type) + new_parent = (self.field.name, outer_type) if new_parent not in new_parents: new_parents.append(new_parent) attr.parents = new_parents @@ -444,7 +440,7 @@ def __init__( if sort_fields: self.sort_fields = self.validate_sort_fields(sort_fields) elif self.knn: - self.sort_fields = [self.knn.score_field] + self.sort_fields = [self.knn.score_field_name] else: self.sort_fields = [] @@ -517,45 +513,36 @@ def query_params(self): def validate_sort_fields(self, sort_fields: List[str]): for sort_field in sort_fields: field_name = sort_field.lstrip("-") - if self.knn and field_name == self.knn.score_field: + if self.knn and field_name == self.knn.score_field_name: continue - if field_name not in self.model.__fields__: # type: ignore + if field_name not in self.model.model_fields: # type: ignore raise QueryNotSupportedError( f"You tried sort by {field_name}, but that field " f"does not exist on the model {self.model}" ) - field_proxy = getattr(self.model, field_name) - if isinstance(field_proxy.field, FieldInfo) or isinstance( - field_proxy.field, PydanticFieldInfo - ): - field_info = field_proxy.field - else: - field_info = field_proxy.field.field_info + field_proxy: ExpressionProxy = getattr(self.model, field_name) - if not getattr(field_info, "sortable", False): + if ( + field_proxy.field.sortable is not True + and field_proxy.field.index is not True + ): raise QueryNotSupportedError( f"You tried sort by {field_name}, but {self.model} does " - f"not define that field as sortable. Docs: {ERRORS_URL}#E2" + f"not define that field as sortable or indexed. Docs: {ERRORS_URL}#E2" ) return sort_fields @staticmethod - def resolve_field_type( - field: Union[ModelField, PydanticFieldInfo], op: Operators - ) -> RediSearchFieldTypes: - field_info: Union[FieldInfo, ModelField, PydanticFieldInfo] + def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes: + field_info: Union[FieldInfo, PydanticFieldInfo] = field - if not hasattr(field, "field_info"): - field_info = field - else: - field_info = field.field_info if getattr(field_info, "primary_key", None) is True: return RediSearchFieldTypes.TAG elif op is Operators.LIKE: fts = getattr(field_info, "full_text_search", None) if fts is not True: # Could be PydanticUndefined raise QuerySyntaxError( - f"You tried to do a full-text search on the field '{field.alias}', " + f"You tried to do a full-text search on the field '{field.name}', " f"but the field is not indexed for full-text search. Use the " f"full_text_search=True option. Docs: {ERRORS_URL}#E3" ) @@ -803,18 +790,9 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: expression.left, NegatedExpression ): result += f"({cls.resolve_redisearch_query(expression.left)})" - elif isinstance(expression.left, ModelField): - field_type = cls.resolve_field_type(expression.left, expression.op) - field_name = expression.left.name - field_info = expression.left.field_info - if not field_info or not getattr(field_info, "index", None): - raise QueryNotSupportedError( - f"You tried to query by a field ({field_name}) " - f"that isn't indexed. Docs: {ERRORS_URL}#E6" - ) elif isinstance(expression.left, FieldInfo): field_type = cls.resolve_field_type(expression.left, expression.op) - field_name = expression.left.alias + field_name = expression.left.name field_info = expression.left if not field_info or not getattr(field_info, "index", None): raise QueryNotSupportedError( @@ -827,11 +805,6 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: f"or an expression enclosed in parentheses. Docs: {ERRORS_URL}#E7" ) - if isinstance(expression.left, ModelField) and expression.parents: - # Build field_name using the specific parents for this expression - prefix = "_".join([p[0] for p in expression.parents]) - field_name = f"{prefix}_{field_name}" - right = expression.right if isinstance(right, Expression) or isinstance(right, NegatedExpression): @@ -857,10 +830,6 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: raise QuerySyntaxError("Could not resolve field type. See docs: TODO") elif not field_info: raise QuerySyntaxError("Could not resolve field info. See docs: TODO") - elif isinstance(right, ModelField): - raise QueryNotSupportedError( - "Comparing fields is not supported. See docs: TODO" - ) else: result += cls.resolve_value( field_name, @@ -1089,6 +1058,8 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): + name: str + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) sortable = kwargs.pop("sortable", Undefined) @@ -1205,66 +1176,22 @@ def schema(self): def Field( - default: Any = Undefined, - *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - exclude: Union[ - AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - include: Union[ - AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - allow_mutation: bool = True, - regex: Optional[str] = None, primary_key: bool = False, sortable: Union[bool, UndefinedType] = Undefined, case_sensitive: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, full_text_search: Union[bool, UndefinedType] = Undefined, vector_options: Optional[VectorFieldOptions] = None, - schema_extra: Optional[Dict[str, Any]] = None, + **kwargs: Unpack[_FieldInfoInputs], ) -> Any: - current_schema_extra = schema_extra or {} field_info = FieldInfo( - default, - default_factory=default_factory, - alias=alias, - title=title, - description=description, - exclude=exclude, - include=include, - const=const, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, - min_length=min_length, - max_length=max_length, - allow_mutation=allow_mutation, - regex=regex, + **kwargs, primary_key=primary_key, sortable=sortable, case_sensitive=case_sensitive, index=index, full_text_search=full_text_search, vector_options=vector_options, - **current_schema_extra, ) return field_info @@ -1272,7 +1199,11 @@ def Field( @dataclasses.dataclass class PrimaryKey: name: str - field: ModelField + field: PydanticFieldInfo + + +class RedisOmConfig(ConfigDict): + index: Optional[bool] class BaseMeta(Protocol): @@ -1309,9 +1240,30 @@ class DefaultMeta: class ModelMeta(ModelMetaclass): _meta: BaseMeta + model_config: RedisOmConfig + model_fields: Dict[str, FieldInfo] # type: ignore[assignment] + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = attrs.pop("Meta", None) - new_class = super().__new__(cls, name, bases, attrs, **kwargs) + + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: Set[str] = { + key + for key in dir(ConfigDict) + if not ( + key.startswith("__") and key.endswith("__") + ) # skip dunder methods and attributes + } + + config_kwargs = { + key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs + } + + new_class: RedisModel = super().__new__( + cls, name, bases, attrs, **config_kwargs + ) # The fact that there is a Meta field and _meta field is important: a # user may have given us a Meta object with their configuration, while @@ -1320,13 +1272,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = meta or getattr(new_class, "Meta", None) base_meta = getattr(new_class, "_meta", None) - if len(bases) >= 1: - for base_index in range(len(bases)): - model_fields = getattr(bases[base_index], "model_fields", []) - for f_name in model_fields: - field = model_fields[f_name] - new_class.model_fields[f_name] = field - if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta new_class._meta = meta @@ -1345,48 +1290,30 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 ) new_class.Meta = new_class._meta + if new_class.model_config.get("index", None) is True: + raise RedisModelError( + f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" + ) + # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) - for field_name, field in new_class.__fields__.items(): - if not isinstance(field, FieldInfo): - for base_candidate in bases: - if hasattr(base_candidate, field_name): - inner_field = getattr(base_candidate, field_name) - if hasattr(inner_field, "field") and isinstance( - getattr(inner_field, "field"), FieldInfo - ): - field.metadata.append(getattr(inner_field, "field")) - field = getattr(inner_field, "field") - - if not field.alias: - field.alias = field_name - setattr(new_class, field_name, ExpressionProxy(field, [])) - annotation = new_class.get_annotations().get(field_name) - if annotation: - new_class.__annotations__[field_name] = Union[ - annotation, ExpressionProxy - ] - else: - new_class.__annotations__[field_name] = ExpressionProxy - # Check if this is our FieldInfo version with extended ORM metadata. - field_info = None - if hasattr(field, "field_info") and isinstance(field.field_info, FieldInfo): - field_info = field.field_info - elif field_name in attrs and isinstance( - attrs.__getitem__(field_name), FieldInfo - ): - field_info = attrs.__getitem__(field_name) - field.field_info = field_info + # Only set if the model is has index=True + is_indexed = kwargs.get("index", None) is True + new_class.model_config["index"] = is_indexed - if field_info is not None: - if field_info.primary_key: - new_class._meta.primary_key = PrimaryKey( - name=field_name, field=field - ) - if field_info.vector_options: - score_attr = f"_{field_name}_score" - setattr(new_class, score_attr, None) - new_class.__annotations__[score_attr] = Union[float, None] + for field_name, field in new_class.model_fields.items(): + if field.__class__ is PydanticFieldInfo: + field = FieldInfo(**field._attributes_set) + setattr(new_class, field_name, field) + + if is_indexed: + setattr(new_class, field_name, ExpressionProxy(field, [])) + + # we need to set the field name for use in queries + field.name = field_name + + if field.primary_key is True: + new_class._meta.primary_key = PrimaryKey(name=field_name, field=field) if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( @@ -1418,16 +1345,20 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 f"{new_class._meta.model_key_prefix}:index" ) - # Not an abstract model class or embedded model, so we should let the + # Model is indexed and not an abstract model class or embedded model, so we should let the # Migrator create indexes for it. - if abc.ABC not in bases and not getattr(new_class._meta, "embedded", False): + if ( + abc.ABC not in bases + and not getattr(new_class._meta, "embedded", False) + and new_class.model_config.get("index") is True + ): key = f"{new_class.__module__}.{new_class.__qualname__}" model_registry[key] = new_class return new_class -def outer_type_or_annotation(field): +def outer_type_or_annotation(field: FieldInfo): if hasattr(field, "outer_type_"): return field.outer_type_ elif not hasattr(field.annotation, "__args__"): @@ -1437,41 +1368,18 @@ def outer_type_or_annotation(field): elif get_origin(field.annotation) == Literal: return str else: - return field.annotation.__args__[0] + return field.annotation.__args__[0] # type: ignore class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): - pk: Optional[str] = Field(default=None, primary_key=True) - if PYDANTIC_V2: - ConfigDict: ClassVar - + pk: Optional[str] = Field(default=None, primary_key=True, validate_default=True) Meta = DefaultMeta - if PYDANTIC_V2: - from pydantic import ConfigDict - - model_config = ConfigDict( - from_attributes=True, arbitrary_types_allowed=True, extra="allow" - ) - else: - - class Config: - orm_mode = True - arbitrary_types_allowed = True - extra = "allow" + model_config = ConfigDict(from_attributes=True) def __init__(__pydantic_self__, **data: Any) -> None: __pydantic_self__.validate_primary_key() - missing_fields = __pydantic_self__.model_fields.keys() - data.keys() - {"pk"} - - kwargs = data.copy() - - # This is a hack, we need to manually make sure we are setting up defaults correctly when we encounter them - # because inheritance apparently won't cover that in pydantic 2.0. - for field in missing_fields: - default_value = __pydantic_self__.model_fields.get(field).default # type: ignore - kwargs[field] = default_value - super().__init__(**kwargs) + super().__init__(**data) def __lt__(self, other): """Default sort: compare primary key of models.""" @@ -1479,6 +1387,12 @@ def __lt__(self, other): def key(self): """Return the Redis key for this model.""" + if self.model_config.get("index", False) is not True: + raise RedisModelError( + "You cannot create a key on a model that is not indexed. " + f"Update your class with index=True: class {self.__class__.__name__}(RedisModel, index=True):" + ) + if hasattr(self._meta.primary_key.field, "name"): pk = getattr(self, self._meta.primary_key.field.name) else: @@ -1519,7 +1433,7 @@ async def expire( # TODO: Wrap any Redis response errors in a custom exception? await db.expire(self.key(), num_seconds) - @validator("pk", always=True, allow_reuse=True) + @field_validator("pk", mode="after") def validate_pk(cls, v): if not v or isinstance(v, ExpressionProxy): v = cls._meta.primary_key_creator_cls().create_pk() @@ -1529,26 +1443,23 @@ def validate_pk(cls, v): def validate_primary_key(cls): """Check for a primary key. We need one (and only one).""" primary_keys = 0 - for name, field in cls.__fields__.items(): - if not hasattr(field, "field_info"): - if ( - not isinstance(field, FieldInfo) - and hasattr(field, "metadata") - and len(field.metadata) > 0 - and isinstance(field.metadata[0], FieldInfo) - ): - field_info = field.metadata[0] - else: - field_info = field + for name, field in cls.model_fields.items(): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] else: - field_info = field.field_info + field_info = field if getattr(field_info, "primary_key", None) is True: primary_keys += 1 if primary_keys == 0: raise RedisModelError("You must define a primary key for the model") elif primary_keys == 2: - cls.__fields__.pop("pk") + cls.model_fields.pop("pk") elif primary_keys > 2: raise RedisModelError("You must define only one primary key for a model") @@ -1602,10 +1513,8 @@ def to_string(s): # $ means a json entry if fields.get("$"): json_fields = json.loads(fields.pop("$")) + json_fields.update(fields) doc = cls(**json_fields) - for k, v in fields.items(): - if k.startswith("__") and k.endswith("_score"): - setattr(doc, k[1:], float(v)) else: doc = cls(**fields) @@ -1674,16 +1583,8 @@ def redisearch_schema(cls): raise NotImplementedError def check(self): - """Run all validations.""" - if not PYDANTIC_V2: - *_, validation_error = validate_model(self.__class__, self.__dict__) - if validation_error: - raise validation_error - else: - from pydantic import TypeAdapter - - adapter = TypeAdapter(self.__class__) - adapter.validate_python(self.__dict__) + adapter = TypeAdapter(self.__class__) + adapter.validate_python(self.__dict__) class HashModel(RedisModel, abc.ABC): @@ -1710,7 +1611,7 @@ def __init_subclass__(cls, **kwargs): f"HashModels cannot index dataclass fields. Field: {name}" ) - for name, field in cls.__fields__.items(): + for name, field in cls.model_fields.items(): outer_type = outer_type_or_annotation(field) origin = get_origin(outer_type) if origin: @@ -1735,7 +1636,7 @@ async def save( ) -> "Model": self.check() db = self._get_db(pipeline) - document = jsonable_encoder(self.dict()) + document = jsonable_encoder(self.model_dump()) # filter out values which are `None` because they are not valid in a HSET document = {k: v for k, v in document.items() if v is not None} @@ -1763,7 +1664,7 @@ async def get(cls: Type["Model"], pk: Any) -> "Model": if not document: raise NotFoundError try: - result = cls.parse_obj(document) + result = cls.model_validate(document) except TypeError as e: log.warning( f'Could not parse Redis response. Error was: "{e}". Probably, the ' @@ -1772,7 +1673,7 @@ async def get(cls: Type["Model"], pk: Any) -> "Model": f"model class ({cls.__class__}. Encoding: {cls.Meta.encoding}." ) document = decode_redis_value(document, cls.Meta.encoding) - result = cls.parse_obj(document) + result = cls.model_validate(document) return result @classmethod @@ -1806,7 +1707,7 @@ async def update(self, **field_values): def schema_for_fields(cls): schema_parts = [] - for name, field in cls.__fields__.items(): + for name, field in cls.model_fields.items(): # TODO: Merge this code with schema_for_type()? _type = outer_type_or_annotation(field) is_subscripted_type = get_origin(_type) @@ -1819,10 +1720,7 @@ def schema_for_fields(cls): ): field = field.metadata[0] - if not hasattr(field, "field_info"): - field_info = field - else: - field_info = field.field_info + field_info = field if getattr(field_info, "primary_key", None) is True: if issubclass(_type, str): @@ -1893,7 +1791,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" elif issubclass(typ, RedisModel): sub_fields = [] - for embedded_name, field in typ.__fields__.items(): + for embedded_name, field in typ.model_fields.items(): sub_fields.append( cls.schema_for_type( f"{name}_{embedded_name}", field.outer_type_, field.field_info @@ -1930,7 +1828,7 @@ async def save( db = self._get_db(pipeline) # TODO: Wrap response errors in a custom exception? - await db.json().set(self.key(), Path.root_path(), json.loads(self.json())) + await db.json().set(self.key(), Path.root_path(), self.model_dump(mode="json")) return self @classmethod @@ -1976,7 +1874,7 @@ async def get(cls: Type["Model"], pk: Any) -> "Model": document = json.dumps(await cls.db().json().get(cls.make_key(pk))) if document == "null": raise NotFoundError - return cls.parse_raw(document) + return cls.model_validate_json(document) @classmethod def redisearch_schema(cls): @@ -1990,7 +1888,7 @@ def schema_for_fields(cls): schema_parts = [] json_path = "$" fields = dict() - for name, field in cls.__fields__.items(): + for name, field in cls.model_fields.items(): fields[name] = field for name, field in cls.__dict__.items(): if isinstance(field, FieldInfo): @@ -2015,10 +1913,8 @@ def schema_for_fields(cls): ): field = field.metadata[0] - if hasattr(field, "field_info"): - field_info = field.field_info - else: - field_info = field + field_info = field + if getattr(field_info, "primary_key", None) is True: if issubclass(_type, str): redisearch_field = f"$.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" @@ -2039,7 +1935,7 @@ def schema_for_type( json_path: str, name: str, name_prefix: str, - typ: Any, + typ: Union[type[RedisModel], Any], field_info: PydanticFieldInfo, parent_type: Optional[Any] = None, ) -> str: @@ -2119,10 +2015,8 @@ def schema_for_type( elif field_is_model: name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] - for embedded_name, field in typ.__fields__.items(): - if hasattr(field, "field_info"): - field_info = field.field_info - elif ( + for embedded_name, field in typ.model_fields.items(): + if ( hasattr(field, "metadata") and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo) diff --git a/pyproject.toml b/pyproject.toml index aac11ba9..de8e622f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "0.3.5" +version = "0.4.0" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -37,7 +37,7 @@ include=[ [tool.poetry.dependencies] python = ">=3.8,<4.0" redis = ">=3.5.3,<6.0.0" -pydantic = ">=1.10.2,<3.0.0" +pydantic = ">=2.0.0,<3.0.0" click = "^8.0.1" types-redis = ">=3.5.9,<5.0.0" python-ulid = "^1.0.3" diff --git a/tests/_compat.py b/tests/_compat.py index 1cd55bf2..9b8f2856 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -1,10 +1 @@ -from aredis_om._compat import PYDANTIC_V2, use_pydantic_2_plus - - -if not use_pydantic_2_plus() and PYDANTIC_V2: - from pydantic.v1 import EmailStr, ValidationError -elif PYDANTIC_V2: - from pydantic import EmailStr, PositiveInt, ValidationError - -else: - from pydantic import EmailStr, ValidationError +from pydantic import EmailStr, PositiveInt, ValidationError diff --git a/tests/test_find_query.py b/tests/test_find_query.py index ecd14e4b..624f2ebd 100644 --- a/tests/test_find_query.py +++ b/tests/test_find_query.py @@ -51,7 +51,7 @@ class Note(EmbeddedJsonModel): description: str = Field(index=True) created_on: datetime.datetime - class Address(EmbeddedJsonModel): + class Address(EmbeddedJsonModel, index=True): address_line_1: str address_line_2: Optional[str] = None city: str = Field(index=True) @@ -60,15 +60,15 @@ class Address(EmbeddedJsonModel): postal_code: str = Field(index=True) note: Optional[Note] = None - class Item(EmbeddedJsonModel): + class Item(EmbeddedJsonModel, index=True): price: decimal.Decimal name: str = Field(index=True) - class Order(EmbeddedJsonModel): + class Order(EmbeddedJsonModel, index=True): items: List[Item] created_on: datetime.datetime - class Member(BaseJsonModel): + class Member(BaseJsonModel, index=True): first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) email: Optional[EmailStr] = Field(index=True, default=None) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 12df8cda..cf4ffe32 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -42,12 +42,12 @@ class BaseHashModel(HashModel, abc.ABC): class Meta: global_key_prefix = key_prefix - class Order(BaseHashModel): + class Order(BaseHashModel, index=True): total: decimal.Decimal currency: str created_on: datetime.datetime - class Member(BaseHashModel): + class Member(BaseHashModel, index=True): id: int = Field(index=True, primary_key=True) first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) @@ -177,7 +177,6 @@ async def test_full_text_search_queries(members, m): @py_test_mark_asyncio -@pytest.mark.xfail(strict=False) async def test_pagination_queries(members, m): member1, member2, member3 = members @@ -524,7 +523,7 @@ async def test_all_pks(m): @py_test_mark_asyncio async def test_all_pks_with_complex_pks(key_prefix): - class City(HashModel): + class City(HashModel, index=True): name: str class Meta: @@ -797,7 +796,7 @@ class Customer2(m.BaseHashModel): bio="Python developer, wanna work at Redis, Inc.", ) - assert "pk" in customer.__fields__ + assert "pk" in customer.model_fields customer = Customer2( id=1, @@ -806,7 +805,7 @@ class Customer2(m.BaseHashModel): bio="This is member 2 who can be quite anxious until you get to know them.", ) - assert "pk" not in customer.__fields__ + assert "pk" not in customer.model_fields @py_test_mark_asyncio @@ -826,7 +825,7 @@ async def test_count(members, m): @py_test_mark_asyncio async def test_type_with_union(members, m): - class TypeWithUnion(m.BaseHashModel): + class TypeWithUnion(m.BaseHashModel, index=True): field: Union[str, int] twu_str = TypeWithUnion(field="hello world") @@ -849,7 +848,7 @@ class TypeWithUnion(m.BaseHashModel): @py_test_mark_asyncio async def test_type_with_uuid(): - class TypeWithUuid(HashModel): + class TypeWithUuid(HashModel, index=True): uuid: uuid.UUID item = TypeWithUuid(uuid=uuid.uuid4()) @@ -894,10 +893,10 @@ async def test_xfix_queries(members, m): @py_test_mark_asyncio async def test_none(): - class ModelWithNoneDefault(HashModel): + class ModelWithNoneDefault(HashModel, index=True): test: Optional[str] = Field(index=True, default=None) - class ModelWithStringDefault(HashModel): + class ModelWithStringDefault(HashModel, index=True): test: Optional[str] = Field(index=True, default="None") await Migrator().run() @@ -915,7 +914,7 @@ class ModelWithStringDefault(HashModel): @py_test_mark_asyncio async def test_update_validation(): - class TestUpdate(HashModel): + class TestUpdate(HashModel, index=True): name: str age: int @@ -936,7 +935,7 @@ class TestUpdate(HashModel): async def test_literals(): from typing import Literal - class TestLiterals(HashModel): + class TestLiterals(HashModel, index=True): flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") schema = TestLiterals.redisearch_schema() @@ -963,7 +962,7 @@ class Model(HashModel): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model): + class Child(Model, index=True): other_name: str # is_new: bool = Field(default=True) @@ -988,7 +987,7 @@ class Model(RedisModel, abc.ABC): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model, HashModel): + class Child(Model, HashModel, index=True): other_name: str # is_new: bool = Field(default=True) @@ -1002,3 +1001,56 @@ class Child(Model, HashModel): assert rematerialized.age == 18 assert rematerialized.bio is None + + +@py_test_mark_asyncio +async def test_model_validate_uses_default_values(): + + class ChildCls: + def __init__(self, first_name: str, other_name: str): + self.first_name = first_name + self.other_name = other_name + + class Model(HashModel): + first_name: str + age: int = Field(default=18) + bio: Optional[str] = Field(default=None) + + class Child(Model): + other_name: str + + child_dict = {"first_name": "Anna", "other_name": "Maria"} + child_cls = ChildCls(**child_dict) + + child_ctor = Child(**child_dict) + + assert child_ctor.first_name == "Anna" + assert child_ctor.age == 18 + assert child_ctor.bio is None + assert child_ctor.other_name == "Maria" + + child_validate = Child.model_validate(child_cls, from_attributes=True) + + assert child_validate.first_name == "Anna" + assert child_validate.age == 18 + assert child_validate.bio is None + assert child_validate.other_name == "Maria" + + +@py_test_mark_asyncio +async def test_model_with_alias_can_be_searched(key_prefix, redis): + class Model(HashModel, index=True): + first_name: str = Field(alias="firstName", index=True) + last_name: str = Field(alias="lastName") + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + model = Model(first_name="Steve", last_name="Lorello") + await model.save() + + rematerialized = await Model.find(Model.first_name == "Steve").first() + assert rematerialized.pk == model.pk diff --git a/tests/test_json_model.py b/tests/test_json_model.py index a3e4d40f..7064095e 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1,7 +1,6 @@ # type: ignore import abc -import asyncio import dataclasses import datetime import decimal @@ -12,7 +11,6 @@ import pytest import pytest_asyncio -from more_itertools.more import first from aredis_om import ( EmbeddedJsonModel, @@ -45,14 +43,14 @@ class BaseJsonModel(JsonModel, abc.ABC): class Meta: global_key_prefix = key_prefix - class Note(EmbeddedJsonModel): + class Note(EmbeddedJsonModel, index=True): # TODO: This was going to be a full-text search example, but # we can't index embedded documents for full-text search in # the preview release. description: str = Field(index=True) created_on: datetime.datetime - class Address(EmbeddedJsonModel): + class Address(EmbeddedJsonModel, index=True): address_line_1: str address_line_2: Optional[str] = None city: str = Field(index=True) @@ -61,15 +59,15 @@ class Address(EmbeddedJsonModel): postal_code: str = Field(index=True) note: Optional[Note] = None - class Item(EmbeddedJsonModel): + class Item(EmbeddedJsonModel, index=True): price: decimal.Decimal name: str = Field(index=True) - class Order(EmbeddedJsonModel): + class Order(EmbeddedJsonModel, index=True): items: List[Item] created_on: datetime.datetime - class Member(BaseJsonModel): + class Member(BaseJsonModel, index=True): first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) email: Optional[EmailStr] = Field(index=True, default=None) @@ -258,7 +256,7 @@ async def test_all_pks(address, m, redis): @py_test_mark_asyncio async def test_all_pks_with_complex_pks(key_prefix): - class City(JsonModel): + class City(JsonModel, index=True): name: str class Meta: @@ -795,7 +793,7 @@ class NumerologyWitch(m.BaseJsonModel): with pytest.raises(RedisModelError): - class ReadingWithPrice(EmbeddedJsonModel): + class ReadingWithPrice(EmbeddedJsonModel, index=True): gold_coins_charged: int = Field(index=True) class TarotWitchWhoCharges(m.BaseJsonModel): @@ -807,7 +805,7 @@ class TarotWitchWhoCharges(m.BaseJsonModel): # The fate of this feature is To Be Determined. readings: List[ReadingWithPrice] - class TarotWitch(m.BaseJsonModel): + class TarotWitch(m.BaseJsonModel, index=True): # We support indexing lists of strings for quality and membership # queries. Sorting is not supported, but is planned. tarot_cards: List[str] = Field(index=True) @@ -829,7 +827,7 @@ async def test_allows_dataclasses(m): class Address: address_line_1: str - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): address: Address address = Address(address_line_1="hey") @@ -843,7 +841,7 @@ class ValidMember(m.BaseJsonModel): @py_test_mark_asyncio async def test_allows_and_serializes_dicts(m): - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): address: Dict[str, str] member = ValidMember(address={"address_line_1": "hey"}) @@ -856,7 +854,7 @@ class ValidMember(m.BaseJsonModel): @py_test_mark_asyncio async def test_allows_and_serializes_sets(m): - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): friend_ids: Set[int] member = ValidMember(friend_ids={1, 2}) @@ -869,7 +867,7 @@ class ValidMember(m.BaseJsonModel): @py_test_mark_asyncio async def test_allows_and_serializes_lists(m): - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): friend_ids: List[int] member = ValidMember(friend_ids=[1, 2]) @@ -922,7 +920,7 @@ async def test_count(members, m): @py_test_mark_asyncio async def test_type_with_union(members, m): - class TypeWithUnion(m.BaseJsonModel): + class TypeWithUnion(m.BaseJsonModel, index=True): field: Union[str, int] twu_str = TypeWithUnion(field="hello world") @@ -945,7 +943,7 @@ class TypeWithUnion(m.BaseJsonModel): @py_test_mark_asyncio async def test_type_with_uuid(): - class TypeWithUuid(JsonModel): + class TypeWithUuid(JsonModel, index=True): uuid: uuid.UUID item = TypeWithUuid(uuid=uuid.uuid4()) @@ -1004,10 +1002,10 @@ async def test_xfix_queries(m): @py_test_mark_asyncio async def test_none(): - class ModelWithNoneDefault(JsonModel): + class ModelWithNoneDefault(JsonModel, index=True): test: Optional[str] = Field(index=True, default=None) - class ModelWithStringDefault(JsonModel): + class ModelWithStringDefault(JsonModel, index=True): test: Optional[str] = Field(index=True, default="None") await Migrator().run() @@ -1025,11 +1023,11 @@ class ModelWithStringDefault(JsonModel): @py_test_mark_asyncio async def test_update_validation(): - class Embedded(EmbeddedJsonModel): + class Embedded(EmbeddedJsonModel, index=True): price: float name: str = Field(index=True) - class TestUpdatesClass(JsonModel): + class TestUpdatesClass(JsonModel, index=True): name: str age: int embedded: Embedded @@ -1056,10 +1054,10 @@ class TestUpdatesClass(JsonModel): @py_test_mark_asyncio async def test_model_with_dict(): - class EmbeddedJsonModelWithDict(EmbeddedJsonModel): + class EmbeddedJsonModelWithDict(EmbeddedJsonModel, index=True): dict: Dict - class ModelWithDict(JsonModel): + class ModelWithDict(JsonModel, index=True): embedded_model: EmbeddedJsonModelWithDict info: Dict @@ -1080,7 +1078,7 @@ class ModelWithDict(JsonModel): @py_test_mark_asyncio async def test_boolean(): - class Example(JsonModel): + class Example(JsonModel, index=True): b: bool = Field(index=True) d: datetime.date = Field(index=True) name: str = Field(index=True) @@ -1103,7 +1101,7 @@ class Example(JsonModel): @py_test_mark_asyncio async def test_int_pk(): - class ModelWithIntPk(JsonModel): + class ModelWithIntPk(JsonModel, index=True): my_id: int = Field(index=True, primary_key=True) await Migrator().run() @@ -1115,7 +1113,7 @@ class ModelWithIntPk(JsonModel): @py_test_mark_asyncio async def test_pagination(): - class Test(JsonModel): + class Test(JsonModel, index=True): id: str = Field(primary_key=True, index=True) num: int = Field(sortable=True, index=True) @@ -1142,7 +1140,7 @@ async def get_page(cls, offset, limit): async def test_literals(): from typing import Literal - class TestLiterals(JsonModel): + class TestLiterals(JsonModel, index=True): flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") schema = TestLiterals.redisearch_schema() @@ -1181,7 +1179,7 @@ class Model(JsonModel): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model): + class Child(Model, index=True): is_new: bool = Field(default=True) await Migrator().run() @@ -1205,13 +1203,13 @@ class Model(RedisModel, abc.ABC): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model, JsonModel): + class Child(Model, JsonModel, index=True): is_new: bool = Field(default=True) await Migrator().run() m = Child(first_name="Steve", last_name="Lorello") await m.save() - print(m.age) + assert m.age == 18 rematerialized = await Child.find(Child.pk == m.pk).first() @@ -1223,10 +1221,10 @@ class Child(Model, JsonModel): @py_test_mark_asyncio async def test_merged_model_error(): - class Player(EmbeddedJsonModel): + class Player(EmbeddedJsonModel, index=True): username: str = Field(index=True) - class Game(JsonModel): + class Game(JsonModel, index=True): player1: Optional[Player] player2: Optional[Player] @@ -1235,3 +1233,77 @@ class Game(JsonModel): ) print(q.query) assert q.query == "(@player1_username:{username})| (@player2_username:{username})" + + +@py_test_mark_asyncio +async def test_model_validate_uses_default_values(): + + class ChildCls: + def __init__(self, first_name: str, other_name: str): + self.first_name = first_name + self.other_name = other_name + + class Model(JsonModel): + first_name: str + age: int = Field(default=18) + bio: Optional[str] = Field(default=None) + + class Child(Model): + other_name: str + + child_dict = {"first_name": "Anna", "other_name": "Maria"} + child_cls = ChildCls(**child_dict) + + child_ctor = Child(**child_dict) + + assert child_ctor.first_name == "Anna" + assert child_ctor.age == 18 + assert child_ctor.bio is None + assert child_ctor.other_name == "Maria" + + child_validate = Child.model_validate(child_cls, from_attributes=True) + + assert child_validate.first_name == "Anna" + assert child_validate.age == 18 + assert child_validate.bio is None + assert child_validate.other_name == "Maria" + + +@py_test_mark_asyncio +async def test_model_raises_error_if_inherited_from_indexed_model(): + class Model(JsonModel, index=True): + pass + + with pytest.raises(RedisModelError): + + class Child(Model): + pass + + +@py_test_mark_asyncio +async def test_non_indexed_model_raises_error_on_save(): + class Model(JsonModel): + pass + + with pytest.raises(RedisModelError): + model = Model() + await model.save() + + +@py_test_mark_asyncio +async def test_model_with_alias_can_be_searched(key_prefix, redis): + class Model(JsonModel, index=True): + first_name: str = Field(alias="firstName", index=True) + last_name: str = Field(alias="lastName") + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + model = Model(first_name="Steve", last_name="Lorello") + await model.save() + + rematerialized = await Model.find(Model.first_name == "Steve").first() + assert rematerialized.pk == model.pk diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py new file mode 100644 index 00000000..929f7bab --- /dev/null +++ b/tests/test_knn_expression.py @@ -0,0 +1,65 @@ +# type: ignore +import abc +import struct +from typing import Optional + +import pytest_asyncio + +from aredis_om import Field, JsonModel, KNNExpression, Migrator, VectorFieldOptions + +from .conftest import py_test_mark_asyncio + + +DIMENSIONS = 768 + + +vector_field_options = VectorFieldOptions.flat( + type=VectorFieldOptions.TYPE.FLOAT32, + dimension=DIMENSIONS, + distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE, +) + + +@pytest_asyncio.fixture +async def m(key_prefix, redis): + class BaseJsonModel(JsonModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + class Member(BaseJsonModel, index=True): + name: str + embeddings: list[list[float]] = Field([], vector_options=vector_field_options) + embeddings_score: Optional[float] = None + + await Migrator().run() + + return Member + + +def to_bytes(vectors: list[float]) -> bytes: + return struct.pack(f"<{len(vectors)}f", *vectors) + + +@py_test_mark_asyncio +async def test_vector_field(m: type[JsonModel]): + # Create a new instance of the Member model + vectors = [0.3 for _ in range(DIMENSIONS)] + member = m(name="seth", embeddings=[vectors]) + + # Save the member to Redis + await member.save() + + knn = KNNExpression( + k=1, + vector_field=m.embeddings, + score_field=m.embeddings_score, + reference_vector=to_bytes(vectors), + ) + + query = m.find(knn=knn) + + members = await query.all() + + assert len(members) == 1 + assert members[0].embeddings_score is not None diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 47ebe47f..b8a57a6e 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -22,12 +22,12 @@ class BaseHashModel(HashModel, abc.ABC): class Meta: global_key_prefix = key_prefix - class Order(BaseHashModel): + class Order(BaseHashModel, index=True): total: decimal.Decimal currency: str created_on: datetime.datetime - class Member(BaseHashModel): + class Member(BaseHashModel, index=True): first_name: str last_name: str email: str @@ -133,7 +133,7 @@ async def test_saves_model_and_creates_pk(m): def test_raises_error_with_embedded_models(m): - class Address(m.BaseHashModel): + class Address(m.BaseHashModel, index=True): address_line_1: str address_line_2: Optional[str] city: str diff --git a/tests/test_pydantic_integrations.py b/tests/test_pydantic_integrations.py index 12d41a9a..04d42db0 100644 --- a/tests/test_pydantic_integrations.py +++ b/tests/test_pydantic_integrations.py @@ -1,9 +1,11 @@ import abc import datetime from collections import namedtuple +from typing import Optional import pytest import pytest_asyncio +from pydantic import field_validator from aredis_om import Field, HashModel, Migrator from tests._compat import EmailStr, ValidationError @@ -48,3 +50,18 @@ def test_email_str(m): age=38, join_date=today, ) + + +def test_validator_sets_value_on_init(): + value = "bar" + + class ModelWithValidator(HashModel): + field: Optional[str] = Field(default=None, index=True) + + @field_validator("field", mode="after") + def set_field(cls, v): + return value + + m = ModelWithValidator(field="foo") + + assert m.field == value