diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 67769cd..5e2aeaa 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -2178,7 +2178,11 @@ def schema_for_type( # a proper type, we can pull the type information from the origin of the first argument. if not isinstance(typ, type): type_args = typing_get_args(field_info.annotation) - typ = type_args[0].__origin__ + typ = ( + getattr(type_args[0], "__origin__", type_args[0]) + if type_args + else typ + ) # TODO: GEO field if is_vector and vector_options: diff --git a/tests/test_json_model.py b/tests/test_json_model.py index a3e4d40..404d59c 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -13,6 +13,7 @@ import pytest import pytest_asyncio from more_itertools.more import first +from redis.exceptions import ResponseError from aredis_om import ( EmbeddedJsonModel, @@ -23,6 +24,7 @@ QueryNotSupportedError, RedisModel, RedisModelError, + VectorFieldOptions, ) # We need to run this check as sync code (during tests) even in async mode @@ -1235,3 +1237,82 @@ class Game(JsonModel): ) print(q.query) assert q.query == "(@player1_username:{username})| (@player2_username:{username})" + + +@py_test_mark_asyncio +def test_vector_field_definition(redis): + """ + Test the definition and behavior of a vector field in a JsonModel. + This test verifies: + 1. The model schema includes "VECTOR" for a vector field with specified options. + 2. Instances with vector fields can be saved and retrieved accurately. + 3. Vector field values remain consistent after persistence. + + Args: + redis: Redis connection fixture for testing. + """ + + class Group(JsonModel): + articles: List[str] + tender_text: str = Field(index=False) + tender_embedding: List[float] = Field( + index=True, + vector_options=VectorFieldOptions( + algorithm=VectorFieldOptions.ALGORITHM.FLAT, + type=VectorFieldOptions.TYPE.FLOAT32, + dimension=3, + distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE, + ), + ) + + schema = Group.redisearch_schema() + assert "VECTOR" in schema + + group = Group( + articles=["article_1", "article_2"], + tender_text="Sample text", + tender_embedding=[0.1, 0.2, 0.3], + ) + group.save() + + retrieved_group = await Group.get(group.pk) + assert retrieved_group.tender_embedding == [0.1, 0.2, 0.3] + + retrieved_group = Group.get(group.pk) + assert retrieved_group.tender_embedding == [0.1, 0.2, 0.3] + + +def test_vector_field_schema_debug(redis): + """ + Test and debug the schema definition for a vector field in a JsonModel. + + This test ensures: + 1. The schema for a vector field is generated correctly. + 2. No syntax errors occur when saving a model instance. + 3. The Redis schema syntax is valid and debugged if issues arise. + + Steps: + - Define a `TestModel` with a vector field `embedding`. + - Attempt to save an instance and print the schema. + - Handle and fail gracefully if Redis raises a ResponseError. + + Args: + redis: Redis connection fixture for testing. + """ + + class TestModel(JsonModel): + embedding: List[float] = Field( + index=True, + vector_options=VectorFieldOptions( + algorithm=VectorFieldOptions.ALGORITHM.FLAT, + type=VectorFieldOptions.TYPE.FLOAT32, + dimension=3, + distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE, + ), + ) + + try: + TestModel(embedding=[0.1, 0.2, 0.3]).save() + print(TestModel.redisearch_schema()) + except ResponseError as e: + pytest.fail(f"Redis rejected the schema with error: {e}")