From ddf335468771aacac0a8759ffb986bd7906548cf Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Fri, 18 Apr 2025 15:22:38 -0400 Subject: [PATCH 01/17] remove pydantic v1 references --- aredis_om/__init__.py | 22 ++++ aredis_om/_compat.py | 99 -------------- aredis_om/async_redis.py | 3 + aredis_om/checks.py | 1 - aredis_om/model/encoders.py | 6 +- aredis_om/model/migrations/migrator.py | 2 +- aredis_om/model/model.py | 170 ++++++++----------------- aredis_om/sync_redis.py | 3 + pyproject.toml | 4 +- tests/_compat.py | 11 +- 10 files changed, 89 insertions(+), 232 deletions(-) delete mode 100644 aredis_om/_compat.py diff --git a/aredis_om/__init__.py b/aredis_om/__init__.py index 813e3b04..debe3f77 100644 --- a/aredis_om/__init__.py +++ b/aredis_om/__init__.py @@ -16,3 +16,25 @@ RedisModelError, VectorFieldOptions, ) + + +__all__ = [ + "redis", + "get_redis_connection", + "Field", + "HashModel", + "JsonModel", + "EmbeddedJsonModel", + "RedisModel", + "FindQuery", + "KNNExpression", + "VectorFieldOptions", + "has_redis_json", + "has_redisearch", + "MigrationError", + "Migrator", + "RedisModelError", + "NotFoundError", + "QueryNotSupportedError", + "QuerySyntaxError", +] diff --git a/aredis_om/_compat.py b/aredis_om/_compat.py deleted file mode 100644 index 07dc2824..00000000 --- a/aredis_om/_compat.py +++ /dev/null @@ -1,99 +0,0 @@ -from dataclasses import dataclass, is_dataclass -from typing import ( - Any, - Callable, - Deque, - Dict, - FrozenSet, - List, - Mapping, - Sequence, - Set, - Tuple, - Type, - Union, -) - -from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Annotated, Literal, get_args, get_origin - - -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - -if PYDANTIC_V2: - - def use_pydantic_2_plus(): - return True - - from pydantic import BaseModel, TypeAdapter - from pydantic import ValidationError as ValidationError - from pydantic import validator - from pydantic._internal._model_construction import ModelMetaclass - from pydantic._internal._repr import Representation - from pydantic.deprecated.json import ENCODERS_BY_TYPE - from pydantic.fields import FieldInfo - from pydantic.v1.main import validate_model - from pydantic.v1.typing import NoArgAnyCallable - from pydantic_core import PydanticUndefined as Undefined - from pydantic_core import PydanticUndefinedType as UndefinedType - - @dataclass - class ModelField: - field_info: FieldInfo - name: str - mode: Literal["validation", "serialization"] = "validation" - - @property - def alias(self) -> str: - a = self.field_info.alias - return a if a is not None else self.name - - @property - def required(self) -> bool: - return self.field_info.is_required() - - @property - def default(self) -> Any: - return self.get_default() - - @property - def type_(self) -> Any: - return self.field_info.annotation - - def __post_init__(self) -> None: - self._type_adapter: TypeAdapter[Any] = TypeAdapter( - Annotated[self.field_info.annotation, self.field_info] - ) - - def get_default(self) -> Any: - if self.field_info.is_required(): - return Undefined - return self.field_info.get_default(call_default_factory=True) - - def validate( - self, - value: Any, - values: Dict[str, Any] = {}, # noqa: B006 - *, - loc: Tuple[Union[int, str], ...] = (), - ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: - return ( - self._type_adapter.validate_python(value, from_attributes=True), - None, - ) - - def __hash__(self) -> int: - # Each ModelField is unique for our purposes, to allow making a dict from - # ModelField to its JSON Schema. - return id(self) - -else: - from pydantic import BaseModel, validator - from pydantic.fields import FieldInfo, ModelField, Undefined, UndefinedType - from pydantic.json import ENCODERS_BY_TYPE - from pydantic.main import ModelMetaclass, validate_model - from pydantic.typing import NoArgAnyCallable - from pydantic.utils import Representation - - def use_pydantic_2_plus(): - return False diff --git a/aredis_om/async_redis.py b/aredis_om/async_redis.py index 9a98a03b..b5fb289f 100644 --- a/aredis_om/async_redis.py +++ b/aredis_om/async_redis.py @@ -1 +1,4 @@ from redis import asyncio as redis + + +__all__ = ["redis"] diff --git a/aredis_om/checks.py b/aredis_om/checks.py index be2332cf..522f5c38 100644 --- a/aredis_om/checks.py +++ b/aredis_om/checks.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import List from aredis_om.connections import get_redis_connection diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index cb71447c..4fa5e88e 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -31,8 +31,9 @@ from types import GeneratorType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union -from .._compat import ENCODERS_BY_TYPE, BaseModel - +from pydantic import BaseModel +from pydantic.deprecated.json import ENCODERS_BY_TYPE +from pydantic_core import PydanticUndefined SetIntStr = Set[Union[int, str]] DictIntStrAny = Dict[Union[int, str], Any] @@ -106,6 +107,7 @@ def jsonable_encoder( or (not isinstance(key, str)) or (not key.startswith("_sa")) ) + and value is not PydanticUndefined and (value is not None or not exclude_none) and ((include and key in include) or not exclude or key not in exclude) ): diff --git a/aredis_om/model/migrations/migrator.py b/aredis_om/model/migrations/migrator.py index eb3d1b0b..34aa7c14 100644 --- a/aredis_om/model/migrations/migrator.py +++ b/aredis_om/model/migrations/migrator.py @@ -131,7 +131,7 @@ async def detect_migrations(self): stored_hash = await conn.get(hash_key) if isinstance(stored_hash, bytes): - stored_hash = stored_hash.decode('utf-8') + stored_hash = stored_hash.decode("utf-8") schema_out_of_date = current_hash != stored_hash diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 67769cd3..aaf54e10 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -11,7 +11,6 @@ AbstractSet, Any, Callable, - ClassVar, Dict, List, Literal, @@ -28,24 +27,18 @@ from typing import no_type_check from more_itertools import ichunked +from pydantic import BaseModel, ConfigDict, TypeAdapter, field_validator +from pydantic._internal._model_construction import ModelMetaclass +from pydantic._internal._repr import Representation +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic_core import PydanticUndefined as Undefined +from pydantic_core import PydanticUndefinedType as UndefinedType from redis.commands.json.path import Path from redis.exceptions import ResponseError from typing_extensions import Protocol, get_args, get_origin from ulid import ULID from .. import redis -from .._compat import PYDANTIC_V2, BaseModel -from .._compat import FieldInfo as PydanticFieldInfo -from .._compat import ( - ModelField, - ModelMetaclass, - NoArgAnyCallable, - Representation, - Undefined, - UndefinedType, - validate_model, - validator, -) from ..checks import has_redis_json, has_redisearch from ..connections import get_redis_connection from ..util import ASYNC_MODE @@ -78,7 +71,7 @@ ERRORS_URL = "https://github.com/redis/redis-om-python/blob/main/docs/errors.md" -def get_outer_type(field): +def get_outer_type(field: PydanticFieldInfo): if hasattr(field, "outer_type_"): return field.outer_type_ elif isinstance(field.annotation, type) or is_supported_container_type( @@ -127,9 +120,7 @@ def __str__(self): return str(self.name) -ExpressionOrModelField = Union[ - "Expression", "NegatedExpression", ModelField, PydanticFieldInfo -] +ExpressionOrModelField = Union["Expression", "NegatedExpression", PydanticFieldInfo] def embedded(cls): @@ -288,11 +279,11 @@ def tree(self): @dataclasses.dataclass class KNNExpression: k: int - vector_field: ModelField + vector_field_name: str reference_vector: bytes def __str__(self): - return f"KNN $K @{self.vector_field.name} $knn_ref_vector" + return f"KNN $K @{self.vector_field_name} $knn_ref_vector" @property def query_params(self) -> Dict[str, Union[str, bytes]]: @@ -300,14 +291,16 @@ def query_params(self) -> Dict[str, Union[str, bytes]]: @property def score_field(self) -> str: - return f"__{self.vector_field.name}_score" + return f"__{self.vector_field_name}_score" ExpressionOrNegated = Union[Expression, NegatedExpression] class ExpressionProxy: - def __init__(self, field: ModelField, parents: List[Tuple[str, "RedisModel"]]): + def __init__( + self, field: PydanticFieldInfo, parents: List[Tuple[str, "RedisModel"]] + ): self.field = field self.parents = parents.copy() # Ensure a copy is stored @@ -519,20 +512,14 @@ def validate_sort_fields(self, sort_fields: List[str]): field_name = sort_field.lstrip("-") if self.knn and field_name == self.knn.score_field: continue - if field_name not in self.model.__fields__: # type: ignore + if field_name not in self.model.model_fields: # type: ignore raise QueryNotSupportedError( f"You tried sort by {field_name}, but that field " f"does not exist on the model {self.model}" ) - field_proxy = getattr(self.model, field_name) - if isinstance(field_proxy.field, FieldInfo) or isinstance( - field_proxy.field, PydanticFieldInfo - ): - field_info = field_proxy.field - else: - field_info = field_proxy.field.field_info + field_proxy: ExpressionProxy = getattr(self.model, field_name) - if not getattr(field_info, "sortable", False): + if not getattr(field_proxy.field, "sortable", False): raise QueryNotSupportedError( f"You tried sort by {field_name}, but {self.model} does " f"not define that field as sortable. Docs: {ERRORS_URL}#E2" @@ -541,14 +528,10 @@ def validate_sort_fields(self, sort_fields: List[str]): @staticmethod def resolve_field_type( - field: Union[ModelField, PydanticFieldInfo], op: Operators + field: PydanticFieldInfo, op: Operators ) -> RediSearchFieldTypes: - field_info: Union[FieldInfo, ModelField, PydanticFieldInfo] + field_info: Union[FieldInfo, PydanticFieldInfo] = field - if not hasattr(field, "field_info"): - field_info = field - else: - field_info = field.field_info if getattr(field_info, "primary_key", None) is True: return RediSearchFieldTypes.TAG elif op is Operators.LIKE: @@ -803,15 +786,6 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: expression.left, NegatedExpression ): result += f"({cls.resolve_redisearch_query(expression.left)})" - elif isinstance(expression.left, ModelField): - field_type = cls.resolve_field_type(expression.left, expression.op) - field_name = expression.left.name - field_info = expression.left.field_info - if not field_info or not getattr(field_info, "index", None): - raise QueryNotSupportedError( - f"You tried to query by a field ({field_name}) " - f"that isn't indexed. Docs: {ERRORS_URL}#E6" - ) elif isinstance(expression.left, FieldInfo): field_type = cls.resolve_field_type(expression.left, expression.op) field_name = expression.left.alias @@ -827,11 +801,6 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: f"or an expression enclosed in parentheses. Docs: {ERRORS_URL}#E7" ) - if isinstance(expression.left, ModelField) and expression.parents: - # Build field_name using the specific parents for this expression - prefix = "_".join([p[0] for p in expression.parents]) - field_name = f"{prefix}_{field_name}" - right = expression.right if isinstance(right, Expression) or isinstance(right, NegatedExpression): @@ -857,10 +826,6 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: raise QuerySyntaxError("Could not resolve field type. See docs: TODO") elif not field_info: raise QuerySyntaxError("Could not resolve field info. See docs: TODO") - elif isinstance(right, ModelField): - raise QueryNotSupportedError( - "Comparing fields is not supported. See docs: TODO" - ) else: result += cls.resolve_value( field_name, @@ -1207,7 +1172,7 @@ def schema(self): def Field( default: Any = Undefined, *, - default_factory: Optional[NoArgAnyCallable] = None, + default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None = None, alias: Optional[str] = None, title: Optional[str] = None, description: Optional[str] = None, @@ -1272,7 +1237,7 @@ def Field( @dataclasses.dataclass class PrimaryKey: name: str - field: ModelField + field: PydanticFieldInfo class BaseMeta(Protocol): @@ -1311,7 +1276,7 @@ class ModelMeta(ModelMetaclass): def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = attrs.pop("Meta", None) - new_class = super().__new__(cls, name, bases, attrs, **kwargs) + new_class: RedisModel = super().__new__(cls, name, bases, attrs, **kwargs) # The fact that there is a Meta field and _meta field is important: a # user may have given us a Meta object with their configuration, while @@ -1347,7 +1312,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) - for field_name, field in new_class.__fields__.items(): + for field_name, field in new_class.model_fields.items(): if not isinstance(field, FieldInfo): for base_candidate in bases: if hasattr(base_candidate, field_name): @@ -1427,7 +1392,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 return new_class -def outer_type_or_annotation(field): +def outer_type_or_annotation(field: FieldInfo): if hasattr(field, "outer_type_"): return field.outer_type_ elif not hasattr(field.annotation, "__args__"): @@ -1442,23 +1407,11 @@ def outer_type_or_annotation(field): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field(default=None, primary_key=True) - if PYDANTIC_V2: - ConfigDict: ClassVar - Meta = DefaultMeta - if PYDANTIC_V2: - from pydantic import ConfigDict - - model_config = ConfigDict( - from_attributes=True, arbitrary_types_allowed=True, extra="allow" - ) - else: - - class Config: - orm_mode = True - arbitrary_types_allowed = True - extra = "allow" + model_config = ConfigDict( + from_attributes=True, arbitrary_types_allowed=True, extra="allow" + ) def __init__(__pydantic_self__, **data: Any) -> None: __pydantic_self__.validate_primary_key() @@ -1519,7 +1472,7 @@ async def expire( # TODO: Wrap any Redis response errors in a custom exception? await db.expire(self.key(), num_seconds) - @validator("pk", always=True, allow_reuse=True) + @field_validator("pk", mode="after") def validate_pk(cls, v): if not v or isinstance(v, ExpressionProxy): v = cls._meta.primary_key_creator_cls().create_pk() @@ -1529,26 +1482,23 @@ def validate_pk(cls, v): def validate_primary_key(cls): """Check for a primary key. We need one (and only one).""" primary_keys = 0 - for name, field in cls.__fields__.items(): - if not hasattr(field, "field_info"): - if ( - not isinstance(field, FieldInfo) - and hasattr(field, "metadata") - and len(field.metadata) > 0 - and isinstance(field.metadata[0], FieldInfo) - ): - field_info = field.metadata[0] - else: - field_info = field + for name, field in cls.model_fields.items(): + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + field_info = field.metadata[0] else: - field_info = field.field_info + field_info = field if getattr(field_info, "primary_key", None) is True: primary_keys += 1 if primary_keys == 0: raise RedisModelError("You must define a primary key for the model") elif primary_keys == 2: - cls.__fields__.pop("pk") + cls.model_fields.pop("pk") elif primary_keys > 2: raise RedisModelError("You must define only one primary key for a model") @@ -1674,16 +1624,8 @@ def redisearch_schema(cls): raise NotImplementedError def check(self): - """Run all validations.""" - if not PYDANTIC_V2: - *_, validation_error = validate_model(self.__class__, self.__dict__) - if validation_error: - raise validation_error - else: - from pydantic import TypeAdapter - - adapter = TypeAdapter(self.__class__) - adapter.validate_python(self.__dict__) + adapter = TypeAdapter(self.__class__) + adapter.validate_python(self.__dict__) class HashModel(RedisModel, abc.ABC): @@ -1710,7 +1652,7 @@ def __init_subclass__(cls, **kwargs): f"HashModels cannot index dataclass fields. Field: {name}" ) - for name, field in cls.__fields__.items(): + for name, field in cls.model_fields.items(): outer_type = outer_type_or_annotation(field) origin = get_origin(outer_type) if origin: @@ -1763,7 +1705,7 @@ async def get(cls: Type["Model"], pk: Any) -> "Model": if not document: raise NotFoundError try: - result = cls.parse_obj(document) + result = cls.model_validate(document) except TypeError as e: log.warning( f'Could not parse Redis response. Error was: "{e}". Probably, the ' @@ -1772,7 +1714,7 @@ async def get(cls: Type["Model"], pk: Any) -> "Model": f"model class ({cls.__class__}. Encoding: {cls.Meta.encoding}." ) document = decode_redis_value(document, cls.Meta.encoding) - result = cls.parse_obj(document) + result = cls.model_validate(document) return result @classmethod @@ -1806,7 +1748,7 @@ async def update(self, **field_values): def schema_for_fields(cls): schema_parts = [] - for name, field in cls.__fields__.items(): + for name, field in cls.model_fields.items(): # TODO: Merge this code with schema_for_type()? _type = outer_type_or_annotation(field) is_subscripted_type = get_origin(_type) @@ -1819,10 +1761,7 @@ def schema_for_fields(cls): ): field = field.metadata[0] - if not hasattr(field, "field_info"): - field_info = field - else: - field_info = field.field_info + field_info = field if getattr(field_info, "primary_key", None) is True: if issubclass(_type, str): @@ -1893,7 +1832,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo): schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" elif issubclass(typ, RedisModel): sub_fields = [] - for embedded_name, field in typ.__fields__.items(): + for embedded_name, field in typ.model_fields.items(): sub_fields.append( cls.schema_for_type( f"{name}_{embedded_name}", field.outer_type_, field.field_info @@ -1930,7 +1869,7 @@ async def save( db = self._get_db(pipeline) # TODO: Wrap response errors in a custom exception? - await db.json().set(self.key(), Path.root_path(), json.loads(self.json())) + await db.json().set(self.key(), Path.root_path(), self.model_dump(mode="json")) return self @classmethod @@ -1990,7 +1929,7 @@ def schema_for_fields(cls): schema_parts = [] json_path = "$" fields = dict() - for name, field in cls.__fields__.items(): + for name, field in cls.model_fields.items(): fields[name] = field for name, field in cls.__dict__.items(): if isinstance(field, FieldInfo): @@ -2015,10 +1954,8 @@ def schema_for_fields(cls): ): field = field.metadata[0] - if hasattr(field, "field_info"): - field_info = field.field_info - else: - field_info = field + field_info = field + if getattr(field_info, "primary_key", None) is True: if issubclass(_type, str): redisearch_field = f"$.{name} AS {name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}" @@ -2117,12 +2054,11 @@ def schema_for_type( parent_type=field_type, ) elif field_is_model: + typ: type[RedisModel] = typ name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] - for embedded_name, field in typ.__fields__.items(): - if hasattr(field, "field_info"): - field_info = field.field_info - elif ( + for embedded_name, field in typ.model_fields.items(): + if ( hasattr(field, "metadata") and len(field.metadata) > 0 and isinstance(field.metadata[0], FieldInfo) diff --git a/aredis_om/sync_redis.py b/aredis_om/sync_redis.py index 1a472c13..ab1a8546 100644 --- a/aredis_om/sync_redis.py +++ b/aredis_om/sync_redis.py @@ -1 +1,4 @@ import redis + + +__all__ = ["redis"] diff --git a/pyproject.toml b/pyproject.toml index aac11ba9..de8e622f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "redis-om" -version = "0.3.5" +version = "0.4.0" description = "Object mappings, and more, for Redis." authors = ["Redis OSS "] maintainers = ["Redis OSS "] @@ -37,7 +37,7 @@ include=[ [tool.poetry.dependencies] python = ">=3.8,<4.0" redis = ">=3.5.3,<6.0.0" -pydantic = ">=1.10.2,<3.0.0" +pydantic = ">=2.0.0,<3.0.0" click = "^8.0.1" types-redis = ">=3.5.9,<5.0.0" python-ulid = "^1.0.3" diff --git a/tests/_compat.py b/tests/_compat.py index 1cd55bf2..9b8f2856 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -1,10 +1 @@ -from aredis_om._compat import PYDANTIC_V2, use_pydantic_2_plus - - -if not use_pydantic_2_plus() and PYDANTIC_V2: - from pydantic.v1 import EmailStr, ValidationError -elif PYDANTIC_V2: - from pydantic import EmailStr, PositiveInt, ValidationError - -else: - from pydantic import EmailStr, ValidationError +from pydantic import EmailStr, PositiveInt, ValidationError From 49077bbb45c2e9e240f4ac7981152ddf0aa24a7f Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Mon, 21 Apr 2025 12:08:07 -0400 Subject: [PATCH 02/17] set validate_default=True --- aredis_om/model/encoders.py | 2 +- aredis_om/model/model.py | 60 +++++-------------------------------- tests/test_hash_model.py | 4 +-- 3 files changed, 11 insertions(+), 55 deletions(-) diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index 4fa5e88e..e29491e8 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -73,7 +73,7 @@ def jsonable_encoder( encoder = getattr(obj.__config__, "json_encoders", {}) if custom_encoder: encoder.update(custom_encoder) - obj_dict = obj.dict( + obj_dict = obj.model_dump( include=include, # type: ignore # in Pydantic exclude=exclude, # type: ignore # in Pydantic by_alias=by_alias, diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index aaf54e10..bda383cb 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -7,8 +7,8 @@ from copy import copy from enum import Enum from functools import reduce +from typing_extensions import Unpack from typing import ( - AbstractSet, Any, Callable, Dict, @@ -30,7 +30,7 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter, field_validator from pydantic._internal._model_construction import ModelMetaclass from pydantic._internal._repr import Representation -from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import FieldInfo as PydanticFieldInfo, _FieldInfoInputs from pydantic_core import PydanticUndefined as Undefined from pydantic_core import PydanticUndefinedType as UndefinedType from redis.commands.json.path import Path @@ -155,7 +155,7 @@ def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any obj = getattr(obj, sub_field) return - if field_name not in model.__fields__: # type: ignore + if field_name not in model.model_fields: # type: ignore raise QuerySyntaxError( f"The field {field_name} does not exist on the model {model.__name__}" ) @@ -1170,66 +1170,22 @@ def schema(self): def Field( - default: Any = Undefined, - *, - default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None = None, - alias: Optional[str] = None, - title: Optional[str] = None, - description: Optional[str] = None, - exclude: Union[ - AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - include: Union[ - AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any - ] = None, - const: Optional[bool] = None, - gt: Optional[float] = None, - ge: Optional[float] = None, - lt: Optional[float] = None, - le: Optional[float] = None, - multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - min_length: Optional[int] = None, - max_length: Optional[int] = None, - allow_mutation: bool = True, - regex: Optional[str] = None, primary_key: bool = False, sortable: Union[bool, UndefinedType] = Undefined, case_sensitive: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, full_text_search: Union[bool, UndefinedType] = Undefined, vector_options: Optional[VectorFieldOptions] = None, - schema_extra: Optional[Dict[str, Any]] = None, + **kwargs: Unpack[_FieldInfoInputs], ) -> Any: - current_schema_extra = schema_extra or {} field_info = FieldInfo( - default, - default_factory=default_factory, - alias=alias, - title=title, - description=description, - exclude=exclude, - include=include, - const=const, - gt=gt, - ge=ge, - lt=lt, - le=le, - multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, - min_length=min_length, - max_length=max_length, - allow_mutation=allow_mutation, - regex=regex, + **kwargs, primary_key=primary_key, sortable=sortable, case_sensitive=case_sensitive, index=index, full_text_search=full_text_search, vector_options=vector_options, - **current_schema_extra, ) return field_info @@ -1410,7 +1366,7 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): Meta = DefaultMeta model_config = ConfigDict( - from_attributes=True, arbitrary_types_allowed=True, extra="allow" + from_attributes=True, arbitrary_types_allowed=True, extra="allow", validate_default=True ) def __init__(__pydantic_self__, **data: Any) -> None: @@ -1677,7 +1633,7 @@ async def save( ) -> "Model": self.check() db = self._get_db(pipeline) - document = jsonable_encoder(self.dict()) + document = jsonable_encoder(self.model_dump()) # filter out values which are `None` because they are not valid in a HSET document = {k: v for k, v in document.items() if v is not None} @@ -1915,7 +1871,7 @@ async def get(cls: Type["Model"], pk: Any) -> "Model": document = json.dumps(await cls.db().json().get(cls.make_key(pk))) if document == "null": raise NotFoundError - return cls.parse_raw(document) + return cls.model_validate_json(document) @classmethod def redisearch_schema(cls): diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 12df8cda..c575897e 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -797,7 +797,7 @@ class Customer2(m.BaseHashModel): bio="Python developer, wanna work at Redis, Inc.", ) - assert "pk" in customer.__fields__ + assert "pk" in customer.model_fields customer = Customer2( id=1, @@ -806,7 +806,7 @@ class Customer2(m.BaseHashModel): bio="This is member 2 who can be quite anxious until you get to know them.", ) - assert "pk" not in customer.__fields__ + assert "pk" not in customer.model_fields @py_test_mark_asyncio From 9ffa5eda30f088d5bd7bfbc18e78eeea0e2e7eb8 Mon Sep 17 00:00:00 2001 From: jmmenard <89422205+jmmenard@users.noreply.github.com> Date: Mon, 21 Apr 2025 15:01:33 -0400 Subject: [PATCH 03/17] Remove unused imports --- tests/test_json_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_json_model.py b/tests/test_json_model.py index a3e4d40f..0f7b7f0a 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1,7 +1,6 @@ # type: ignore import abc -import asyncio import dataclasses import datetime import decimal @@ -12,7 +11,7 @@ import pytest import pytest_asyncio -from more_itertools.more import first + from aredis_om import ( EmbeddedJsonModel, From 3676634f9914e8b1049f9c4d3c5eaa641334aeec Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Thu, 24 Apr 2025 14:35:56 -0400 Subject: [PATCH 04/17] update models to specify when a model is final --- .vscode/settings.json | 4 + aredis_om/model/model.py | 134 +++++++++++++++------------- tests/test_find_query.py | 8 +- tests/test_hash_model.py | 57 +++++++++--- tests/test_json_model.py | 112 ++++++++++++++++------- tests/test_oss_redis_features.py | 6 +- tests/test_pydantic_integrations.py | 17 ++++ 7 files changed, 226 insertions(+), 112 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..2e783ba7 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, +} diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index bda383cb..72f73439 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -7,7 +7,6 @@ from copy import copy from enum import Enum from functools import reduce -from typing_extensions import Unpack from typing import ( Any, Callable, @@ -30,12 +29,13 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter, field_validator from pydantic._internal._model_construction import ModelMetaclass from pydantic._internal._repr import Representation -from pydantic.fields import FieldInfo as PydanticFieldInfo, _FieldInfoInputs +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import _FieldInfoInputs from pydantic_core import PydanticUndefined as Undefined from pydantic_core import PydanticUndefinedType as UndefinedType from redis.commands.json.path import Path from redis.exceptions import ResponseError -from typing_extensions import Protocol, get_args, get_origin +from typing_extensions import Protocol, Unpack, get_args, get_origin from ulid import ULID from .. import redis @@ -280,6 +280,7 @@ def tree(self): class KNNExpression: k: int vector_field_name: str + score_field_name: str reference_vector: bytes def __str__(self): @@ -291,7 +292,7 @@ def query_params(self) -> Dict[str, Union[str, bytes]]: @property def score_field(self) -> str: - return f"__{self.vector_field_name}_score" + return self.score_field_name or f"_{self.vector_field_name}_score" ExpressionOrNegated = Union[Expression, NegatedExpression] @@ -1176,10 +1177,10 @@ def Field( index: Union[bool, UndefinedType] = Undefined, full_text_search: Union[bool, UndefinedType] = Undefined, vector_options: Optional[VectorFieldOptions] = None, - **kwargs: Unpack[_FieldInfoInputs], + **kwargs: Unpack[_FieldInfoInputs], ) -> Any: field_info = FieldInfo( - **kwargs, + **kwargs, primary_key=primary_key, sortable=sortable, case_sensitive=case_sensitive, @@ -1196,6 +1197,10 @@ class PrimaryKey: field: PydanticFieldInfo +class RedisOmConfig(ConfigDict): + index: bool | None + + class BaseMeta(Protocol): global_key_prefix: str model_key_prefix: str @@ -1230,9 +1235,30 @@ class DefaultMeta: class ModelMeta(ModelMetaclass): _meta: BaseMeta + model_config: RedisOmConfig + model_fields: Dict[str, FieldInfo] # type: ignore[assignment] + def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = attrs.pop("Meta", None) - new_class: RedisModel = super().__new__(cls, name, bases, attrs, **kwargs) + + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: Set[str] = { + key + for key in dir(ConfigDict) + if not ( + key.startswith("__") and key.endswith("__") + ) # skip dunder methods and attributes + } + + config_kwargs = { + key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs + } + + new_class: RedisModel = super().__new__( + cls, name, bases, attrs, **config_kwargs + ) # The fact that there is a Meta field and _meta field is important: a # user may have given us a Meta object with their configuration, while @@ -1241,13 +1267,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 meta = meta or getattr(new_class, "Meta", None) base_meta = getattr(new_class, "_meta", None) - if len(bases) >= 1: - for base_index in range(len(bases)): - model_fields = getattr(bases[base_index], "model_fields", []) - for f_name in model_fields: - field = model_fields[f_name] - new_class.model_fields[f_name] = field - if meta and meta != DefaultMeta and meta != base_meta: new_class.Meta = meta new_class._meta = meta @@ -1266,49 +1285,35 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 ) new_class.Meta = new_class._meta + if new_class.model_config.get("index", None) is True: + raise RedisModelError( + f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" + ) + # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) - for field_name, field in new_class.model_fields.items(): - if not isinstance(field, FieldInfo): - for base_candidate in bases: - if hasattr(base_candidate, field_name): - inner_field = getattr(base_candidate, field_name) - if hasattr(inner_field, "field") and isinstance( - getattr(inner_field, "field"), FieldInfo - ): - field.metadata.append(getattr(inner_field, "field")) - field = getattr(inner_field, "field") - - if not field.alias: - field.alias = field_name - setattr(new_class, field_name, ExpressionProxy(field, [])) - annotation = new_class.get_annotations().get(field_name) - if annotation: - new_class.__annotations__[field_name] = Union[ - annotation, ExpressionProxy - ] - else: - new_class.__annotations__[field_name] = ExpressionProxy - # Check if this is our FieldInfo version with extended ORM metadata. - field_info = None - if hasattr(field, "field_info") and isinstance(field.field_info, FieldInfo): - field_info = field.field_info - elif field_name in attrs and isinstance( - attrs.__getitem__(field_name), FieldInfo - ): - field_info = attrs.__getitem__(field_name) - field.field_info = field_info - - if field_info is not None: - if field_info.primary_key: + # Only set if the model is has index=True + if kwargs.get("index", None) == True: + new_class.model_config["index"] = True + for field_name, field in new_class.model_fields.items(): + setattr(new_class, field_name, ExpressionProxy(field, [])) + + # We need to set alias equal the field name here to allow downstream processes to have access to it. + # Processes like the query builder use it. + if not field.alias: + field.alias = field_name + + if getattr(field, "primary_key", None) is True: new_class._meta.primary_key = PrimaryKey( name=field_name, field=field ) - if field_info.vector_options: + if getattr(field, "vector_options", None) is not None: score_attr = f"_{field_name}_score" setattr(new_class, score_attr, None) new_class.__annotations__[score_attr] = Union[float, None] + new_class.model_config["from_attributes"] = True + if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( base_meta, "global_key_prefix", "" @@ -1339,9 +1344,13 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 f"{new_class._meta.model_key_prefix}:index" ) - # Not an abstract model class or embedded model, so we should let the + # Model is indexed and not an abstract model class or embedded model, so we should let the # Migrator create indexes for it. - if abc.ABC not in bases and not getattr(new_class._meta, "embedded", False): + if ( + abc.ABC not in bases + and not getattr(new_class._meta, "embedded", False) + and new_class.model_config.get("index") is True + ): key = f"{new_class.__module__}.{new_class.__qualname__}" model_registry[key] = new_class @@ -1366,21 +1375,15 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): Meta = DefaultMeta model_config = ConfigDict( - from_attributes=True, arbitrary_types_allowed=True, extra="allow", validate_default=True + from_attributes=True, + arbitrary_types_allowed=True, + extra="allow", + validate_default=True, ) def __init__(__pydantic_self__, **data: Any) -> None: __pydantic_self__.validate_primary_key() - missing_fields = __pydantic_self__.model_fields.keys() - data.keys() - {"pk"} - - kwargs = data.copy() - - # This is a hack, we need to manually make sure we are setting up defaults correctly when we encounter them - # because inheritance apparently won't cover that in pydantic 2.0. - for field in missing_fields: - default_value = __pydantic_self__.model_fields.get(field).default # type: ignore - kwargs[field] = default_value - super().__init__(**kwargs) + super().__init__(**data) def __lt__(self, other): """Default sort: compare primary key of models.""" @@ -1388,6 +1391,12 @@ def __lt__(self, other): def key(self): """Return the Redis key for this model.""" + if self.model_config.get("index", False) is not True: + raise RedisModelError( + "You cannot create a key on a model that is not indexed. " + f"Update your class with index=True: class {self.__class__.__name__}(RedisModel, index=True):" + ) + if hasattr(self._meta.primary_key.field, "name"): pk = getattr(self, self._meta.primary_key.field.name) else: @@ -1932,7 +1941,7 @@ def schema_for_type( json_path: str, name: str, name_prefix: str, - typ: Any, + typ: type[RedisModel] | Any, field_info: PydanticFieldInfo, parent_type: Optional[Any] = None, ) -> str: @@ -2010,7 +2019,6 @@ def schema_for_type( parent_type=field_type, ) elif field_is_model: - typ: type[RedisModel] = typ name_prefix = f"{name_prefix}_{name}" if name_prefix else name sub_fields = [] for embedded_name, field in typ.model_fields.items(): diff --git a/tests/test_find_query.py b/tests/test_find_query.py index ecd14e4b..624f2ebd 100644 --- a/tests/test_find_query.py +++ b/tests/test_find_query.py @@ -51,7 +51,7 @@ class Note(EmbeddedJsonModel): description: str = Field(index=True) created_on: datetime.datetime - class Address(EmbeddedJsonModel): + class Address(EmbeddedJsonModel, index=True): address_line_1: str address_line_2: Optional[str] = None city: str = Field(index=True) @@ -60,15 +60,15 @@ class Address(EmbeddedJsonModel): postal_code: str = Field(index=True) note: Optional[Note] = None - class Item(EmbeddedJsonModel): + class Item(EmbeddedJsonModel, index=True): price: decimal.Decimal name: str = Field(index=True) - class Order(EmbeddedJsonModel): + class Order(EmbeddedJsonModel, index=True): items: List[Item] created_on: datetime.datetime - class Member(BaseJsonModel): + class Member(BaseJsonModel, index=True): first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) email: Optional[EmailStr] = Field(index=True, default=None) diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index c575897e..ecbfe3d9 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -42,12 +42,12 @@ class BaseHashModel(HashModel, abc.ABC): class Meta: global_key_prefix = key_prefix - class Order(BaseHashModel): + class Order(BaseHashModel, index=True): total: decimal.Decimal currency: str created_on: datetime.datetime - class Member(BaseHashModel): + class Member(BaseHashModel, index=True): id: int = Field(index=True, primary_key=True) first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) @@ -177,7 +177,6 @@ async def test_full_text_search_queries(members, m): @py_test_mark_asyncio -@pytest.mark.xfail(strict=False) async def test_pagination_queries(members, m): member1, member2, member3 = members @@ -524,7 +523,7 @@ async def test_all_pks(m): @py_test_mark_asyncio async def test_all_pks_with_complex_pks(key_prefix): - class City(HashModel): + class City(HashModel, index=True): name: str class Meta: @@ -826,7 +825,7 @@ async def test_count(members, m): @py_test_mark_asyncio async def test_type_with_union(members, m): - class TypeWithUnion(m.BaseHashModel): + class TypeWithUnion(m.BaseHashModel, index=True): field: Union[str, int] twu_str = TypeWithUnion(field="hello world") @@ -849,7 +848,7 @@ class TypeWithUnion(m.BaseHashModel): @py_test_mark_asyncio async def test_type_with_uuid(): - class TypeWithUuid(HashModel): + class TypeWithUuid(HashModel, index=True): uuid: uuid.UUID item = TypeWithUuid(uuid=uuid.uuid4()) @@ -894,10 +893,10 @@ async def test_xfix_queries(members, m): @py_test_mark_asyncio async def test_none(): - class ModelWithNoneDefault(HashModel): + class ModelWithNoneDefault(HashModel, index=True): test: Optional[str] = Field(index=True, default=None) - class ModelWithStringDefault(HashModel): + class ModelWithStringDefault(HashModel, index=True): test: Optional[str] = Field(index=True, default="None") await Migrator().run() @@ -915,7 +914,7 @@ class ModelWithStringDefault(HashModel): @py_test_mark_asyncio async def test_update_validation(): - class TestUpdate(HashModel): + class TestUpdate(HashModel, index=True): name: str age: int @@ -936,7 +935,7 @@ class TestUpdate(HashModel): async def test_literals(): from typing import Literal - class TestLiterals(HashModel): + class TestLiterals(HashModel, index=True): flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") schema = TestLiterals.redisearch_schema() @@ -963,7 +962,7 @@ class Model(HashModel): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model): + class Child(Model, index=True): other_name: str # is_new: bool = Field(default=True) @@ -988,7 +987,7 @@ class Model(RedisModel, abc.ABC): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model, HashModel): + class Child(Model, HashModel, index=True): other_name: str # is_new: bool = Field(default=True) @@ -1002,3 +1001,37 @@ class Child(Model, HashModel): assert rematerialized.age == 18 assert rematerialized.bio is None + + +@py_test_mark_asyncio +async def test_model_validate_uses_default_values(): + + class ChildCls: + def __init__(self, first_name: str, other_name: str): + self.first_name = first_name + self.other_name = other_name + + class Model(HashModel): + first_name: str + age: int = Field(default=18) + bio: Optional[str] = Field(default=None) + + class Child(Model): + other_name: str + + child_dict = {"first_name": "Anna", "other_name": "Maria"} + child_cls = ChildCls(**child_dict) + + child_ctor = Child(**child_dict) + + assert child_ctor.first_name == "Anna" + assert child_ctor.age == 18 + assert child_ctor.bio is None + assert child_ctor.other_name == "Maria" + + child_validate = Child.model_validate(child_cls, from_attributes=True) + + assert child_validate.first_name == "Anna" + assert child_validate.age == 18 + assert child_validate.bio is None + assert child_validate.other_name == "Maria" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 0f7b7f0a..04d8c2cf 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -11,7 +11,7 @@ import pytest import pytest_asyncio - +from pydantic import field_validator from aredis_om import ( EmbeddedJsonModel, @@ -44,14 +44,14 @@ class BaseJsonModel(JsonModel, abc.ABC): class Meta: global_key_prefix = key_prefix - class Note(EmbeddedJsonModel): + class Note(EmbeddedJsonModel, index=True): # TODO: This was going to be a full-text search example, but # we can't index embedded documents for full-text search in # the preview release. description: str = Field(index=True) created_on: datetime.datetime - class Address(EmbeddedJsonModel): + class Address(EmbeddedJsonModel, index=True): address_line_1: str address_line_2: Optional[str] = None city: str = Field(index=True) @@ -60,15 +60,15 @@ class Address(EmbeddedJsonModel): postal_code: str = Field(index=True) note: Optional[Note] = None - class Item(EmbeddedJsonModel): + class Item(EmbeddedJsonModel, index=True): price: decimal.Decimal name: str = Field(index=True) - class Order(EmbeddedJsonModel): + class Order(EmbeddedJsonModel, index=True): items: List[Item] created_on: datetime.datetime - class Member(BaseJsonModel): + class Member(BaseJsonModel, index=True): first_name: str = Field(index=True, case_sensitive=True) last_name: str = Field(index=True) email: Optional[EmailStr] = Field(index=True, default=None) @@ -257,7 +257,7 @@ async def test_all_pks(address, m, redis): @py_test_mark_asyncio async def test_all_pks_with_complex_pks(key_prefix): - class City(JsonModel): + class City(JsonModel, index=True): name: str class Meta: @@ -794,7 +794,7 @@ class NumerologyWitch(m.BaseJsonModel): with pytest.raises(RedisModelError): - class ReadingWithPrice(EmbeddedJsonModel): + class ReadingWithPrice(EmbeddedJsonModel, index=True): gold_coins_charged: int = Field(index=True) class TarotWitchWhoCharges(m.BaseJsonModel): @@ -806,7 +806,7 @@ class TarotWitchWhoCharges(m.BaseJsonModel): # The fate of this feature is To Be Determined. readings: List[ReadingWithPrice] - class TarotWitch(m.BaseJsonModel): + class TarotWitch(m.BaseJsonModel, index=True): # We support indexing lists of strings for quality and membership # queries. Sorting is not supported, but is planned. tarot_cards: List[str] = Field(index=True) @@ -828,7 +828,7 @@ async def test_allows_dataclasses(m): class Address: address_line_1: str - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): address: Address address = Address(address_line_1="hey") @@ -842,7 +842,7 @@ class ValidMember(m.BaseJsonModel): @py_test_mark_asyncio async def test_allows_and_serializes_dicts(m): - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): address: Dict[str, str] member = ValidMember(address={"address_line_1": "hey"}) @@ -855,7 +855,7 @@ class ValidMember(m.BaseJsonModel): @py_test_mark_asyncio async def test_allows_and_serializes_sets(m): - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): friend_ids: Set[int] member = ValidMember(friend_ids={1, 2}) @@ -868,7 +868,7 @@ class ValidMember(m.BaseJsonModel): @py_test_mark_asyncio async def test_allows_and_serializes_lists(m): - class ValidMember(m.BaseJsonModel): + class ValidMember(m.BaseJsonModel, index=True): friend_ids: List[int] member = ValidMember(friend_ids=[1, 2]) @@ -921,7 +921,7 @@ async def test_count(members, m): @py_test_mark_asyncio async def test_type_with_union(members, m): - class TypeWithUnion(m.BaseJsonModel): + class TypeWithUnion(m.BaseJsonModel, index=True): field: Union[str, int] twu_str = TypeWithUnion(field="hello world") @@ -944,7 +944,7 @@ class TypeWithUnion(m.BaseJsonModel): @py_test_mark_asyncio async def test_type_with_uuid(): - class TypeWithUuid(JsonModel): + class TypeWithUuid(JsonModel, index=True): uuid: uuid.UUID item = TypeWithUuid(uuid=uuid.uuid4()) @@ -1003,10 +1003,10 @@ async def test_xfix_queries(m): @py_test_mark_asyncio async def test_none(): - class ModelWithNoneDefault(JsonModel): + class ModelWithNoneDefault(JsonModel, index=True): test: Optional[str] = Field(index=True, default=None) - class ModelWithStringDefault(JsonModel): + class ModelWithStringDefault(JsonModel, index=True): test: Optional[str] = Field(index=True, default="None") await Migrator().run() @@ -1024,11 +1024,11 @@ class ModelWithStringDefault(JsonModel): @py_test_mark_asyncio async def test_update_validation(): - class Embedded(EmbeddedJsonModel): + class Embedded(EmbeddedJsonModel, index=True): price: float name: str = Field(index=True) - class TestUpdatesClass(JsonModel): + class TestUpdatesClass(JsonModel, index=True): name: str age: int embedded: Embedded @@ -1055,10 +1055,10 @@ class TestUpdatesClass(JsonModel): @py_test_mark_asyncio async def test_model_with_dict(): - class EmbeddedJsonModelWithDict(EmbeddedJsonModel): + class EmbeddedJsonModelWithDict(EmbeddedJsonModel, index=True): dict: Dict - class ModelWithDict(JsonModel): + class ModelWithDict(JsonModel, index=True): embedded_model: EmbeddedJsonModelWithDict info: Dict @@ -1079,7 +1079,7 @@ class ModelWithDict(JsonModel): @py_test_mark_asyncio async def test_boolean(): - class Example(JsonModel): + class Example(JsonModel, index=True): b: bool = Field(index=True) d: datetime.date = Field(index=True) name: str = Field(index=True) @@ -1102,7 +1102,7 @@ class Example(JsonModel): @py_test_mark_asyncio async def test_int_pk(): - class ModelWithIntPk(JsonModel): + class ModelWithIntPk(JsonModel, index=True): my_id: int = Field(index=True, primary_key=True) await Migrator().run() @@ -1114,7 +1114,7 @@ class ModelWithIntPk(JsonModel): @py_test_mark_asyncio async def test_pagination(): - class Test(JsonModel): + class Test(JsonModel, index=True): id: str = Field(primary_key=True, index=True) num: int = Field(sortable=True, index=True) @@ -1141,7 +1141,7 @@ async def get_page(cls, offset, limit): async def test_literals(): from typing import Literal - class TestLiterals(JsonModel): + class TestLiterals(JsonModel, index=True): flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple") schema = TestLiterals.redisearch_schema() @@ -1180,7 +1180,7 @@ class Model(JsonModel): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model): + class Child(Model, index=True): is_new: bool = Field(default=True) await Migrator().run() @@ -1204,13 +1204,13 @@ class Model(RedisModel, abc.ABC): age: int = Field(default=18) bio: Optional[str] = Field(default=None) - class Child(Model, JsonModel): + class Child(Model, JsonModel, index=True): is_new: bool = Field(default=True) await Migrator().run() m = Child(first_name="Steve", last_name="Lorello") await m.save() - print(m.age) + assert m.age == 18 rematerialized = await Child.find(Child.pk == m.pk).first() @@ -1222,10 +1222,10 @@ class Child(Model, JsonModel): @py_test_mark_asyncio async def test_merged_model_error(): - class Player(EmbeddedJsonModel): + class Player(EmbeddedJsonModel, index=True): username: str = Field(index=True) - class Game(JsonModel): + class Game(JsonModel, index=True): player1: Optional[Player] player2: Optional[Player] @@ -1234,3 +1234,55 @@ class Game(JsonModel): ) print(q.query) assert q.query == "(@player1_username:{username})| (@player2_username:{username})" + + +@py_test_mark_asyncio +async def test_model_validate_uses_default_values(): + + class ChildCls: + def __init__(self, first_name: str, other_name: str): + self.first_name = first_name + self.other_name = other_name + + class Model(JsonModel): + first_name: str + age: int = Field(default=18) + bio: Optional[str] = Field(default=None) + + class Child(Model): + other_name: str + + child_dict = {"first_name": "Anna", "other_name": "Maria"} + child_cls = ChildCls(**child_dict) + + child_ctor = Child(**child_dict) + + assert child_ctor.first_name == "Anna" + assert child_ctor.age == 18 + assert child_ctor.bio is None + assert child_ctor.other_name == "Maria" + + child_validate = Child.model_validate(child_cls, from_attributes=True) + + assert child_validate.first_name == "Anna" + assert child_validate.age == 18 + assert child_validate.bio is None + assert child_validate.other_name == "Maria" + +@py_test_mark_asyncio +async def test_model_raises_error_if_inherited_from_indexed_model(): + class Model(JsonModel, index=True): + pass + + with pytest.raises(RedisModelError): + class Child(Model): + pass + +@py_test_mark_asyncio +async def test_non_indexed_model_raises_error_on_save(): + class Model(JsonModel): + pass + + with pytest.raises(RedisModelError): + model = Model() + await model.save() diff --git a/tests/test_oss_redis_features.py b/tests/test_oss_redis_features.py index 47ebe47f..b8a57a6e 100644 --- a/tests/test_oss_redis_features.py +++ b/tests/test_oss_redis_features.py @@ -22,12 +22,12 @@ class BaseHashModel(HashModel, abc.ABC): class Meta: global_key_prefix = key_prefix - class Order(BaseHashModel): + class Order(BaseHashModel, index=True): total: decimal.Decimal currency: str created_on: datetime.datetime - class Member(BaseHashModel): + class Member(BaseHashModel, index=True): first_name: str last_name: str email: str @@ -133,7 +133,7 @@ async def test_saves_model_and_creates_pk(m): def test_raises_error_with_embedded_models(m): - class Address(m.BaseHashModel): + class Address(m.BaseHashModel, index=True): address_line_1: str address_line_2: Optional[str] city: str diff --git a/tests/test_pydantic_integrations.py b/tests/test_pydantic_integrations.py index 12d41a9a..fa9e376c 100644 --- a/tests/test_pydantic_integrations.py +++ b/tests/test_pydantic_integrations.py @@ -4,6 +4,7 @@ import pytest import pytest_asyncio +from pydantic import field_validator from aredis_om import Field, HashModel, Migrator from tests._compat import EmailStr, ValidationError @@ -48,3 +49,19 @@ def test_email_str(m): age=38, join_date=today, ) + + + +def test_validator_sets_value_on_init(): + value = "bar" + + class ModelWithValidator(HashModel): + field: str | None = Field(default=None, index=True) + + @field_validator("field", mode="after") + def set_field(cls, v): + return value + + m = ModelWithValidator() + + assert m.field == value From e0f3b12d91b1dd895bd1e1f5754c60cb70e231b4 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Thu, 24 Apr 2025 16:38:35 -0400 Subject: [PATCH 05/17] fix knn setup --- aredis_om/__init__.py | 22 ---------- aredis_om/async_redis.py | 3 -- aredis_om/model/encoders.py | 1 + aredis_om/model/model.py | 36 ++++++---------- aredis_om/sync_redis.py | 3 -- tests/test_json_model.py | 6 ++- tests/test_knn_expression.py | 65 +++++++++++++++++++++++++++++ tests/test_pydantic_integrations.py | 3 +- 8 files changed, 84 insertions(+), 55 deletions(-) create mode 100644 tests/test_knn_expression.py diff --git a/aredis_om/__init__.py b/aredis_om/__init__.py index debe3f77..813e3b04 100644 --- a/aredis_om/__init__.py +++ b/aredis_om/__init__.py @@ -16,25 +16,3 @@ RedisModelError, VectorFieldOptions, ) - - -__all__ = [ - "redis", - "get_redis_connection", - "Field", - "HashModel", - "JsonModel", - "EmbeddedJsonModel", - "RedisModel", - "FindQuery", - "KNNExpression", - "VectorFieldOptions", - "has_redis_json", - "has_redisearch", - "MigrationError", - "Migrator", - "RedisModelError", - "NotFoundError", - "QueryNotSupportedError", - "QuerySyntaxError", -] diff --git a/aredis_om/async_redis.py b/aredis_om/async_redis.py index b5fb289f..9a98a03b 100644 --- a/aredis_om/async_redis.py +++ b/aredis_om/async_redis.py @@ -1,4 +1 @@ from redis import asyncio as redis - - -__all__ = ["redis"] diff --git a/aredis_om/model/encoders.py b/aredis_om/model/encoders.py index e29491e8..236133e7 100644 --- a/aredis_om/model/encoders.py +++ b/aredis_om/model/encoders.py @@ -35,6 +35,7 @@ from pydantic.deprecated.json import ENCODERS_BY_TYPE from pydantic_core import PydanticUndefined + SetIntStr = Set[Union[int, str]] DictIntStrAny = Dict[Union[int, str], Any] diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 72f73439..3e8ef47e 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -279,20 +279,24 @@ def tree(self): @dataclasses.dataclass class KNNExpression: k: int - vector_field_name: str - score_field_name: str + vector_field: "ExpressionProxy" + score_field: "ExpressionProxy" reference_vector: bytes def __str__(self): - return f"KNN $K @{self.vector_field_name} $knn_ref_vector" + return f"KNN $K @{self.vector_field_name} $knn_ref_vector AS {self.score_field_name}" @property def query_params(self) -> Dict[str, Union[str, bytes]]: return {"K": str(self.k), "knn_ref_vector": self.reference_vector} @property - def score_field(self) -> str: - return self.score_field_name or f"_{self.vector_field_name}_score" + def score_field_name(self) -> str: + return self.score_field.field.alias + + @property + def vector_field_name(self) -> str: + return self.vector_field.field.alias ExpressionOrNegated = Union[Expression, NegatedExpression] @@ -438,7 +442,7 @@ def __init__( if sort_fields: self.sort_fields = self.validate_sort_fields(sort_fields) elif self.knn: - self.sort_fields = [self.knn.score_field] + self.sort_fields = [self.knn.score_field_name] else: self.sort_fields = [] @@ -511,7 +515,7 @@ def query_params(self): def validate_sort_fields(self, sort_fields: List[str]): for sort_field in sort_fields: field_name = sort_field.lstrip("-") - if self.knn and field_name == self.knn.score_field: + if self.knn and field_name == self.knn.score_field_name: continue if field_name not in self.model.model_fields: # type: ignore raise QueryNotSupportedError( @@ -1307,12 +1311,6 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 new_class._meta.primary_key = PrimaryKey( name=field_name, field=field ) - if getattr(field, "vector_options", None) is not None: - score_attr = f"_{field_name}_score" - setattr(new_class, score_attr, None) - new_class.__annotations__[score_attr] = Union[float, None] - - new_class.model_config["from_attributes"] = True if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( @@ -1371,15 +1369,10 @@ def outer_type_or_annotation(field: FieldInfo): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): - pk: Optional[str] = Field(default=None, primary_key=True) + pk: Optional[str] = Field(default=None, primary_key=True, validate_default=True) Meta = DefaultMeta - model_config = ConfigDict( - from_attributes=True, - arbitrary_types_allowed=True, - extra="allow", - validate_default=True, - ) + model_config = ConfigDict(from_attributes=True) def __init__(__pydantic_self__, **data: Any) -> None: __pydantic_self__.validate_primary_key() @@ -1518,9 +1511,6 @@ def to_string(s): if fields.get("$"): json_fields = json.loads(fields.pop("$")) doc = cls(**json_fields) - for k, v in fields.items(): - if k.startswith("__") and k.endswith("_score"): - setattr(doc, k[1:], float(v)) else: doc = cls(**fields) diff --git a/aredis_om/sync_redis.py b/aredis_om/sync_redis.py index ab1a8546..1a472c13 100644 --- a/aredis_om/sync_redis.py +++ b/aredis_om/sync_redis.py @@ -1,4 +1 @@ import redis - - -__all__ = ["redis"] diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 04d8c2cf..4f3bb713 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -11,7 +11,6 @@ import pytest import pytest_asyncio -from pydantic import field_validator from aredis_om import ( EmbeddedJsonModel, @@ -1269,15 +1268,18 @@ class Child(Model): assert child_validate.bio is None assert child_validate.other_name == "Maria" + @py_test_mark_asyncio async def test_model_raises_error_if_inherited_from_indexed_model(): class Model(JsonModel, index=True): - pass + pass with pytest.raises(RedisModelError): + class Child(Model): pass + @py_test_mark_asyncio async def test_non_indexed_model_raises_error_on_save(): class Model(JsonModel): diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py new file mode 100644 index 00000000..04d4a727 --- /dev/null +++ b/tests/test_knn_expression.py @@ -0,0 +1,65 @@ +# type: ignore +import abc +import time + +import pytest_asyncio + +from aredis_om import Field, JsonModel, KNNExpression, Migrator, VectorFieldOptions + +from .conftest import py_test_mark_asyncio + + +vector_field_options = VectorFieldOptions.flat( + type=VectorFieldOptions.TYPE.FLOAT32, + dimension=768, + distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE, +) + + +@pytest_asyncio.fixture +async def m(key_prefix, redis): + class BaseJsonModel(JsonModel, abc.ABC): + class Meta: + global_key_prefix = key_prefix + database = redis + + class Member(BaseJsonModel, index=True): + name: str + embeddings: list[list[float]] = Field([], vector_options=vector_field_options) + embeddings_score: float | None = None + + await Migrator().run() + + return Member + + +@pytest_asyncio.fixture +async def embedding_bytes(): + return b"\x00" * 3072 + + +@py_test_mark_asyncio +async def test_vector_field(m: type[JsonModel], embedding_bytes): + # Create a new instance of the Member model + member = m(name="seth", embeddings=[[0.1, 0.2, 0.3]]) + + # Save the member to Redis + mt = await member.save() + + assert m.get(mt.pk) + + time.sleep(1) + + knn = KNNExpression( + k=1, + vector_field=m.embeddings, + score_field=m.embeddings_score, + reference_vector=embedding_bytes, + ) + + query = m.find() + + members = await query.all() + + assert len(members) == 1 + assert members[0].embeddings_score is not None diff --git a/tests/test_pydantic_integrations.py b/tests/test_pydantic_integrations.py index fa9e376c..38ab35a7 100644 --- a/tests/test_pydantic_integrations.py +++ b/tests/test_pydantic_integrations.py @@ -51,7 +51,6 @@ def test_email_str(m): ) - def test_validator_sets_value_on_init(): value = "bar" @@ -62,6 +61,6 @@ class ModelWithValidator(HashModel): def set_field(cls, v): return value - m = ModelWithValidator() + m = ModelWithValidator(field="foo") assert m.field == value From 298f2e1d596f189fda1f59abf2f47008209af4d1 Mon Sep 17 00:00:00 2001 From: jmmenard <89422205+jmmenard@users.noreply.github.com> Date: Fri, 25 Apr 2025 09:17:34 -0400 Subject: [PATCH 06/17] Update KNN test query --- tests/test_knn_expression.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py index 04d4a727..bdaf3c27 100644 --- a/tests/test_knn_expression.py +++ b/tests/test_knn_expression.py @@ -1,6 +1,7 @@ # type: ignore import abc import time +import random import pytest_asyncio @@ -41,7 +42,9 @@ async def embedding_bytes(): @py_test_mark_asyncio async def test_vector_field(m: type[JsonModel], embedding_bytes): # Create a new instance of the Member model - member = m(name="seth", embeddings=[[0.1, 0.2, 0.3]]) + dimensions = m.embeddings.field.vector_options.dimension + embeddings = [random.uniform(-1, 1) for _ in range(dimensions)] + member = m(name="seth", embeddings=[embeddings]) # Save the member to Redis mt = await member.save() @@ -57,7 +60,7 @@ async def test_vector_field(m: type[JsonModel], embedding_bytes): reference_vector=embedding_bytes, ) - query = m.find() + query = m.find(knn=knn) members = await query.all() From 721d91e1a73d12171e98d10319703be0edc3ecd0 Mon Sep 17 00:00:00 2001 From: jmmenard <89422205+jmmenard@users.noreply.github.com> Date: Fri, 25 Apr 2025 11:52:04 -0400 Subject: [PATCH 07/17] Update KNN results include score --- aredis_om/model/model.py | 1 + tests/test_knn_expression.py | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 3e8ef47e..89366edc 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1510,6 +1510,7 @@ def to_string(s): # $ means a json entry if fields.get("$"): json_fields = json.loads(fields.pop("$")) + json_fields.update(fields) doc = cls(**json_fields) else: doc = cls(**fields) diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py index bdaf3c27..5b8c65d5 100644 --- a/tests/test_knn_expression.py +++ b/tests/test_knn_expression.py @@ -1,7 +1,7 @@ # type: ignore import abc import time -import random +import struct import pytest_asyncio @@ -9,10 +9,12 @@ from .conftest import py_test_mark_asyncio +DIMENSIONS = 768 + vector_field_options = VectorFieldOptions.flat( type=VectorFieldOptions.TYPE.FLOAT32, - dimension=768, + dimension=DIMENSIONS, distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE, ) @@ -26,7 +28,9 @@ class Meta: class Member(BaseJsonModel, index=True): name: str - embeddings: list[list[float]] = Field([], vector_options=vector_field_options) + embeddings: list[list[float]] | bytes = Field( + [], vector_options=vector_field_options + ) embeddings_score: float | None = None await Migrator().run() @@ -34,17 +38,15 @@ class Member(BaseJsonModel, index=True): return Member -@pytest_asyncio.fixture -async def embedding_bytes(): - return b"\x00" * 3072 +def to_bytes(vectors: list[float]) -> bytes: + return struct.pack(f"<{len(vectors)}f", *vectors) @py_test_mark_asyncio -async def test_vector_field(m: type[JsonModel], embedding_bytes): +async def test_vector_field(m: type[JsonModel]): # Create a new instance of the Member model - dimensions = m.embeddings.field.vector_options.dimension - embeddings = [random.uniform(-1, 1) for _ in range(dimensions)] - member = m(name="seth", embeddings=[embeddings]) + vectors = [0.3 for _ in range(DIMENSIONS)] + member = m(name="seth", embeddings=[vectors]) # Save the member to Redis mt = await member.save() @@ -57,7 +59,7 @@ async def test_vector_field(m: type[JsonModel], embedding_bytes): k=1, vector_field=m.embeddings, score_field=m.embeddings_score, - reference_vector=embedding_bytes, + reference_vector=to_bytes(vectors), ) query = m.find(knn=knn) From 8bc61a185adb4b1bd06f4dc91e4ec6d82a5bb74a Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Fri, 25 Apr 2025 13:18:09 -0400 Subject: [PATCH 08/17] remove use of alias for queries --- .vscode/settings.json | 4 --- aredis_om/model/model.py | 51 +++++++++++++++++++----------------- tests/test_hash_model.py | 19 ++++++++++++++ tests/test_json_model.py | 19 ++++++++++++++ tests/test_knn_expression.py | 13 +++------ 5 files changed, 69 insertions(+), 37 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 2e783ba7..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, -} diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 89366edc..87faf57f 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -292,20 +292,18 @@ def query_params(self) -> Dict[str, Union[str, bytes]]: @property def score_field_name(self) -> str: - return self.score_field.field.alias + return self.score_field.field.name @property def vector_field_name(self) -> str: - return self.vector_field.field.alias + return self.vector_field.field.name ExpressionOrNegated = Union[Expression, NegatedExpression] class ExpressionProxy: - def __init__( - self, field: PydanticFieldInfo, parents: List[Tuple[str, "RedisModel"]] - ): + def __init__(self, field: "FieldInfo", parents: List[Tuple[str, "RedisModel"]]): self.field = field self.parents = parents.copy() # Ensure a copy is stored @@ -389,7 +387,7 @@ def __getattr__(self, item): if isinstance(attr, self.__class__): # Clone the parents to ensure isolation new_parents = self.parents.copy() - new_parent = (self.field.alias, outer_type) + new_parent = (self.field.name, outer_type) if new_parent not in new_parents: new_parents.append(new_parent) attr.parents = new_parents @@ -524,17 +522,18 @@ def validate_sort_fields(self, sort_fields: List[str]): ) field_proxy: ExpressionProxy = getattr(self.model, field_name) - if not getattr(field_proxy.field, "sortable", False): + if ( + not field_proxy.field.sortable is True + and not field_proxy.field.index is True + ): raise QueryNotSupportedError( f"You tried sort by {field_name}, but {self.model} does " - f"not define that field as sortable. Docs: {ERRORS_URL}#E2" + f"not define that field as sortable or indexed. Docs: {ERRORS_URL}#E2" ) return sort_fields @staticmethod - def resolve_field_type( - field: PydanticFieldInfo, op: Operators - ) -> RediSearchFieldTypes: + def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldTypes: field_info: Union[FieldInfo, PydanticFieldInfo] = field if getattr(field_info, "primary_key", None) is True: @@ -543,7 +542,7 @@ def resolve_field_type( fts = getattr(field_info, "full_text_search", None) if fts is not True: # Could be PydanticUndefined raise QuerySyntaxError( - f"You tried to do a full-text search on the field '{field.alias}', " + f"You tried to do a full-text search on the field '{field.name}', " f"but the field is not indexed for full-text search. Use the " f"full_text_search=True option. Docs: {ERRORS_URL}#E3" ) @@ -793,7 +792,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: result += f"({cls.resolve_redisearch_query(expression.left)})" elif isinstance(expression.left, FieldInfo): field_type = cls.resolve_field_type(expression.left, expression.op) - field_name = expression.left.alias + field_name = expression.left.name field_info = expression.left if not field_info or not getattr(field_info, "index", None): raise QueryNotSupportedError( @@ -1059,6 +1058,8 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): + name: str + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) sortable = kwargs.pop("sortable", Undefined) @@ -1297,20 +1298,22 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) # Only set if the model is has index=True - if kwargs.get("index", None) == True: - new_class.model_config["index"] = True - for field_name, field in new_class.model_fields.items(): + is_indexed = kwargs.get("index", None) is True + new_class.model_config["index"] = is_indexed + + for field_name, field in new_class.model_fields.items(): + if field.__class__ is PydanticFieldInfo: + field = FieldInfo(**field._attributes_set) + setattr(new_class, field_name, field) + + if is_indexed: setattr(new_class, field_name, ExpressionProxy(field, [])) - # We need to set alias equal the field name here to allow downstream processes to have access to it. - # Processes like the query builder use it. - if not field.alias: - field.alias = field_name + # we need to set the field name for use in queries + field.name = field_name - if getattr(field, "primary_key", None) is True: - new_class._meta.primary_key = PrimaryKey( - name=field_name, field=field - ) + if field.primary_key is True: + new_class._meta.primary_key = PrimaryKey(name=field_name, field=field) if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index ecbfe3d9..cf4ffe32 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -1035,3 +1035,22 @@ class Child(Model): assert child_validate.age == 18 assert child_validate.bio is None assert child_validate.other_name == "Maria" + + +@py_test_mark_asyncio +async def test_model_with_alias_can_be_searched(key_prefix, redis): + class Model(HashModel, index=True): + first_name: str = Field(alias="firstName", index=True) + last_name: str = Field(alias="lastName") + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + model = Model(first_name="Steve", last_name="Lorello") + await model.save() + + rematerialized = await Model.find(Model.first_name == "Steve").first() + assert rematerialized.pk == model.pk diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 4f3bb713..7064095e 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1288,3 +1288,22 @@ class Model(JsonModel): with pytest.raises(RedisModelError): model = Model() await model.save() + + +@py_test_mark_asyncio +async def test_model_with_alias_can_be_searched(key_prefix, redis): + class Model(JsonModel, index=True): + first_name: str = Field(alias="firstName", index=True) + last_name: str = Field(alias="lastName") + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + model = Model(first_name="Steve", last_name="Lorello") + await model.save() + + rematerialized = await Model.find(Model.first_name == "Steve").first() + assert rematerialized.pk == model.pk diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py index 5b8c65d5..9bd8211c 100644 --- a/tests/test_knn_expression.py +++ b/tests/test_knn_expression.py @@ -1,7 +1,7 @@ # type: ignore import abc -import time import struct +import time import pytest_asyncio @@ -9,6 +9,7 @@ from .conftest import py_test_mark_asyncio + DIMENSIONS = 768 @@ -28,9 +29,7 @@ class Meta: class Member(BaseJsonModel, index=True): name: str - embeddings: list[list[float]] | bytes = Field( - [], vector_options=vector_field_options - ) + embeddings: list[list[float]] = Field([], vector_options=vector_field_options) embeddings_score: float | None = None await Migrator().run() @@ -49,11 +48,7 @@ async def test_vector_field(m: type[JsonModel]): member = m(name="seth", embeddings=[vectors]) # Save the member to Redis - mt = await member.save() - - assert m.get(mt.pk) - - time.sleep(1) + await member.save() knn = KNNExpression( k=1, From d6e9b2336a83d5935163848581421ceab54685c0 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Fri, 25 Apr 2025 13:50:38 -0400 Subject: [PATCH 09/17] fix lint issues --- aredis_om/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 87faf57f..ad3f25b7 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -523,8 +523,8 @@ def validate_sort_fields(self, sort_fields: List[str]): field_proxy: ExpressionProxy = getattr(self.model, field_name) if ( - not field_proxy.field.sortable is True - and not field_proxy.field.index is True + field_proxy.field.sortable is not True + and field_proxy.field.index is not True ): raise QueryNotSupportedError( f"You tried sort by {field_name}, but {self.model} does " From ad07ceed96cb9e02191a60c3d67f0dea4deb7a23 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Fri, 25 Apr 2025 13:55:16 -0400 Subject: [PATCH 10/17] fix lint issues --- aredis_om/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index ad3f25b7..f06de216 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -81,7 +81,7 @@ def get_outer_type(field: PydanticFieldInfo): elif not hasattr(field.annotation, "__args__"): return None else: - return field.annotation.__args__[0] + return field.annotation.__args__[0] # type: ignore class RedisModelError(Exception): @@ -1368,7 +1368,7 @@ def outer_type_or_annotation(field: FieldInfo): elif get_origin(field.annotation) == Literal: return str else: - return field.annotation.__args__[0] + return field.annotation.__args__[0] # type: ignore class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): From 623f71445f037c4ccd2363030ade5f2189845cf6 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Fri, 25 Apr 2025 14:00:51 -0400 Subject: [PATCH 11/17] fix annotations for python < 3.10 --- aredis_om/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index f06de216..45333eb4 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1203,7 +1203,7 @@ class PrimaryKey: class RedisOmConfig(ConfigDict): - index: bool | None + index: Optional[bool] class BaseMeta(Protocol): @@ -1935,7 +1935,7 @@ def schema_for_type( json_path: str, name: str, name_prefix: str, - typ: type[RedisModel] | Any, + typ: Union[type[RedisModel], Any], field_info: PydanticFieldInfo, parent_type: Optional[Any] = None, ) -> str: From fc66b8057134f6dd096ce6e1ff9ace98860f06df Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Fri, 25 Apr 2025 14:08:24 -0400 Subject: [PATCH 12/17] fix annotations for python < 3.10 --- tests/test_knn_expression.py | 4 ++-- tests/test_pydantic_integrations.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py index 9bd8211c..929f7bab 100644 --- a/tests/test_knn_expression.py +++ b/tests/test_knn_expression.py @@ -1,7 +1,7 @@ # type: ignore import abc import struct -import time +from typing import Optional import pytest_asyncio @@ -30,7 +30,7 @@ class Meta: class Member(BaseJsonModel, index=True): name: str embeddings: list[list[float]] = Field([], vector_options=vector_field_options) - embeddings_score: float | None = None + embeddings_score: Optional[float] = None await Migrator().run() diff --git a/tests/test_pydantic_integrations.py b/tests/test_pydantic_integrations.py index 38ab35a7..04d42db0 100644 --- a/tests/test_pydantic_integrations.py +++ b/tests/test_pydantic_integrations.py @@ -1,6 +1,7 @@ import abc import datetime from collections import namedtuple +from typing import Optional import pytest import pytest_asyncio @@ -55,7 +56,7 @@ def test_validator_sets_value_on_init(): value = "bar" class ModelWithValidator(HashModel): - field: str | None = Field(default=None, index=True) + field: Optional[str] = Field(default=None, index=True) @field_validator("field", mode="after") def set_field(cls, v): From 3ed225ceaa457add4600ec246126069ddd3dc3ee Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Mon, 28 Apr 2025 10:56:44 -0400 Subject: [PATCH 13/17] fix issue where certain fields were unintentionally being indexed --- .vscode/settings.json | 4 ++++ aredis_om/model/model.py | 23 ++++++++++++++++------- tests/test_json_model.py | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..2e783ba7 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, +} diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 45333eb4..92e3da42 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1060,12 +1060,12 @@ def __dataclass_transform__( class FieldInfo(PydanticFieldInfo): name: str - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + def __init__(self, default: Any = ..., **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) - sortable = kwargs.pop("sortable", Undefined) - case_sensitive = kwargs.pop("case_sensitive", Undefined) - index = kwargs.pop("index", Undefined) - full_text_search = kwargs.pop("full_text_search", Undefined) + sortable = kwargs.pop("sortable", None) + case_sensitive = kwargs.pop("case_sensitive", None) + index = kwargs.pop("index", None) + full_text_search = kwargs.pop("full_text_search", None) vector_options = kwargs.pop("vector_options", None) super().__init__(default=default, **kwargs) self.primary_key = primary_key @@ -1372,7 +1372,13 @@ def outer_type_or_annotation(field: FieldInfo): class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): - pk: Optional[str] = Field(default=None, primary_key=True, validate_default=True) + pk: Optional[str] = Field( + # Indexing for backwards compatibility, we might not want this in the future + default=None, + primary_key=True, + validate_default=True, + index=True, + ) Meta = DefaultMeta model_config = ConfigDict(from_attributes=True) @@ -1939,7 +1945,10 @@ def schema_for_type( field_info: PydanticFieldInfo, parent_type: Optional[Any] = None, ) -> str: - should_index = getattr(field_info, "index", False) + should_index = ( + getattr(field_info, "index", False) is True + or getattr(field_info, "vector_options", None) is not None + ) is_container_type = is_supported_container_type(typ) parent_is_container_type = is_supported_container_type(parent_type) parent_is_model = False diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 7064095e..6d4ea42e 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -6,6 +6,7 @@ import decimal import uuid from collections import namedtuple +from enum import Enum from typing import Dict, List, Optional, Set, Union from unittest import mock @@ -951,6 +952,45 @@ class TypeWithUuid(JsonModel, index=True): await item.save() +@py_test_mark_asyncio +async def test_type_with_enum(): + class TestEnum(Enum): + FOO = "foo" + BAR = "bar" + + class TypeWithEnum(JsonModel, index=True): + enum: TestEnum + + await Migrator().run() + + item = TypeWithEnum(enum=TestEnum.FOO) + + await item.save() + + assert await TypeWithEnum.get(item.pk) == item + + +@py_test_mark_asyncio +async def test_type_with_list_of_enums(key_prefix, redis): + class TestEnum(Enum): + FOO = "foo" + BAR = "bar" + + class BaseWithEnums(JsonModel): + enums: list[TestEnum] + + class TypeWithEnums(BaseWithEnums, index=True): + pass + + await Migrator().run() + + item = TypeWithEnums(enums=[TestEnum.FOO]) + + await item.save() + + assert await TypeWithEnums.get(item.pk) == item + + @py_test_mark_asyncio async def test_xfix_queries(m): await m.Member( From 23213aff7be335070792ab29534532159b5b44f6 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Mon, 28 Apr 2025 11:11:11 -0400 Subject: [PATCH 14/17] remove vscode settings --- .vscode/settings.json | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 2e783ba7..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, -} From b45ffd5bf1c648d3bced1c590b889ea74a489a37 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Mon, 28 Apr 2025 14:41:35 -0400 Subject: [PATCH 15/17] dynamically set index based on other fields that require indexes --- aredis_om/model/model.py | 27 +++++++++++++++++---------- tests/test_hash_model.py | 4 ++-- tests/test_json_model.py | 15 +++++++++++++-- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 92e3da42..d5c54131 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1290,19 +1290,20 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 ) new_class.Meta = new_class._meta - if new_class.model_config.get("index", None) is True: + is_indexed = kwargs.get("index", None) is True + + if is_indexed and new_class.model_config.get("index", None) is True: raise RedisModelError( f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree" ) + new_class.model_config["index"] = is_indexed + # Create proxies for each model field so that we can use the field # in queries, like Model.get(Model.field_name == 1) # Only set if the model is has index=True - is_indexed = kwargs.get("index", None) is True - new_class.model_config["index"] = is_indexed - for field_name, field in new_class.model_fields.items(): - if field.__class__ is PydanticFieldInfo: + if type(field) is PydanticFieldInfo: field = FieldInfo(**field._attributes_set) setattr(new_class, field_name, field) @@ -1370,6 +1371,15 @@ def outer_type_or_annotation(field: FieldInfo): else: return field.annotation.__args__[0] # type: ignore +def should_index_field(field_info: FieldInfo) -> bool: + # for vector, full text search, and sortable fields, we always have to index + # We could require the user to set index=True, but that would be a breaking change + return ( + getattr(field_info, "index", False) is True + or getattr(field_info, "vector_options", None) is not None + or getattr(field_info, "full_text_search", False) is True + or getattr(field_info, "sortable", False) is True + ) class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field( @@ -1736,7 +1746,7 @@ def schema_for_fields(cls): else: redisearch_field = cls.schema_for_type(name, _type, field_info) schema_parts.append(redisearch_field) - elif getattr(field_info, "index", None) is True: + elif should_index_field(field_info): schema_parts.append(cls.schema_for_type(name, _type, field_info)) elif is_subscripted_type: # Ignore subscripted types (usually containers!) that we don't @@ -1945,10 +1955,7 @@ def schema_for_type( field_info: PydanticFieldInfo, parent_type: Optional[Any] = None, ) -> str: - should_index = ( - getattr(field_info, "index", False) is True - or getattr(field_info, "vector_options", None) is not None - ) + should_index = should_index_field(field_info) is_container_type = is_supported_container_type(typ) parent_is_container_type = is_supported_container_type(parent_type) parent_is_model = False diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index cf4ffe32..e84c12de 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -53,8 +53,8 @@ class Member(BaseHashModel, index=True): last_name: str = Field(index=True) email: str = Field(index=True) join_date: datetime.date - age: int = Field(index=True, sortable=True) - bio: str = Field(index=True, full_text_search=True) + age: int = Field(sortable=True) + bio: str = Field(full_text_search=True) class Meta: model_key_prefix = "member" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 6d4ea42e..4ae1cf54 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -74,7 +74,7 @@ class Member(BaseJsonModel, index=True): email: Optional[EmailStr] = Field(index=True, default=None) join_date: datetime.date age: Optional[PositiveInt] = Field(index=True, default=None) - bio: Optional[str] = Field(index=True, full_text_search=True, default="") + bio: Optional[str] = Field(full_text_search=True, default="") # Creates an embedded model. address: Address @@ -1316,10 +1316,21 @@ class Model(JsonModel, index=True): with pytest.raises(RedisModelError): - class Child(Model): + class Child(Model, index=True): pass +@py_test_mark_asyncio +async def test_model_inherited_from_indexed_model(): + class Model(JsonModel, index=True): + name: str = "Steve" + + class Child(Model): + pass + + assert issubclass(Child, Model) + + @py_test_mark_asyncio async def test_non_indexed_model_raises_error_on_save(): class Model(JsonModel): From 48a092bcfb0c9472aa8caa6a6674f56273d13a42 Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Mon, 28 Apr 2025 15:06:34 -0400 Subject: [PATCH 16/17] fix inheritance from indexed model --- aredis_om/model/model.py | 3 ++- tests/test_hash_model.py | 6 +++--- tests/test_json_model.py | 4 ++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index d5c54131..1190b06f 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1394,7 +1394,8 @@ class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): model_config = ConfigDict(from_attributes=True) def __init__(__pydantic_self__, **data: Any) -> None: - __pydantic_self__.validate_primary_key() + if __pydantic_self__.model_config.get("index") is True: + __pydantic_self__.validate_primary_key() super().__init__(**data) def __lt__(self, other): diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index e84c12de..99bc36b4 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -754,7 +754,7 @@ class Address(m.BaseHashModel): @py_test_mark_asyncio async def test_primary_key_model_error(m): - class Customer(m.BaseHashModel): + class Customer(m.BaseHashModel, index=True): id: int = Field(primary_key=True, index=True) first_name: str = Field(primary_key=True, index=True) last_name: str @@ -775,13 +775,13 @@ class Customer(m.BaseHashModel): @py_test_mark_asyncio async def test_primary_pk_exists(m): - class Customer1(m.BaseHashModel): + class Customer1(m.BaseHashModel, index=True): id: int first_name: str last_name: str bio: Optional[str] - class Customer2(m.BaseHashModel): + class Customer2(m.BaseHashModel, index=True): id: int = Field(primary_key=True, index=True) first_name: str last_name: str diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 4ae1cf54..32246135 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1329,6 +1329,10 @@ class Child(Model): pass assert issubclass(Child, Model) + + child = Child(name="John") + + assert child.name == "John" @py_test_mark_asyncio From d5520ed686f8e2de1e586633fd499be16c0ebcfd Mon Sep 17 00:00:00 2001 From: Seth Buffenbarger Date: Mon, 28 Apr 2025 15:08:57 -0400 Subject: [PATCH 17/17] fix linting errors --- aredis_om/model/model.py | 10 ++++++---- tests/test_json_model.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 1190b06f..9ac3ee9d 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -1371,16 +1371,18 @@ def outer_type_or_annotation(field: FieldInfo): else: return field.annotation.__args__[0] # type: ignore -def should_index_field(field_info: FieldInfo) -> bool: + +def should_index_field(field_info: PydanticFieldInfo) -> bool: # for vector, full text search, and sortable fields, we always have to index # We could require the user to set index=True, but that would be a breaking change return ( getattr(field_info, "index", False) is True - or getattr(field_info, "vector_options", None) is not None - or getattr(field_info, "full_text_search", False) is True - or getattr(field_info, "sortable", False) is True + or getattr(field_info, "vector_options", None) is not None + or getattr(field_info, "full_text_search", False) is True + or getattr(field_info, "sortable", False) is True ) + class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field( # Indexing for backwards compatibility, we might not want this in the future diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 32246135..d6428e18 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1329,7 +1329,7 @@ class Child(Model): pass assert issubclass(Child, Model) - + child = Child(name="John") assert child.name == "John"