diff --git a/aredis_om/__init__.py b/aredis_om/__init__.py index 7aa699dd..813e3b04 100644 --- a/aredis_om/__init__.py +++ b/aredis_om/__init__.py @@ -8,11 +8,11 @@ FindQuery, HashModel, JsonModel, - VectorFieldOptions, KNNExpression, NotFoundError, QueryNotSupportedError, QuerySyntaxError, RedisModel, RedisModelError, + VectorFieldOptions, ) diff --git a/aredis_om/model/__init__.py b/aredis_om/model/__init__.py index b9ecf36f..fcdce89d 100644 --- a/aredis_om/model/__init__.py +++ b/aredis_om/model/__init__.py @@ -4,8 +4,8 @@ Field, HashModel, JsonModel, - VectorFieldOptions, KNNExpression, NotFoundError, RedisModel, + VectorFieldOptions, ) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index a4c6b9e7..3644b63b 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -25,6 +25,8 @@ ) from more_itertools import ichunked +from redis import Redis +from redis.asyncio import Redis as RedisAsync from redis.commands.json.path import Path from redis.exceptions import ResponseError from typing_extensions import Protocol, get_args, get_origin @@ -1255,9 +1257,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 base_meta, "primary_key_pattern", "{pk}" ) if not getattr(new_class._meta, "database", None): - new_class._meta.database = getattr( - base_meta, "database", get_redis_connection() - ) + new_class._meta.database = getattr(base_meta, "database", None) if not getattr(new_class._meta, "encoding", None): new_class._meta.encoding = getattr(base_meta, "encoding") if not getattr(new_class._meta, "primary_key_creator_cls", None): @@ -1282,6 +1282,7 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901 class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta): pk: Optional[str] = Field(default=None, primary_key=True) + _conn: Optional[Union[Redis, RedisAsync]] = None Meta = DefaultMeta @@ -1370,7 +1371,19 @@ def make_primary_key(cls, pk: Any): @classmethod def db(cls): - return cls._meta.database + if not cls._conn: + conn = ( + cls._meta.database() + if callable(cls._meta.database) + else cls._meta.database or get_redis_connection() + ) + if not has_redis_json(conn): + log.error( + "Your Redis instance does not have the RedisJson module " + "loaded. JsonModel depends on RedisJson." + ) + cls._conn = conn + return cls._conn @classmethod def find( @@ -1674,14 +1687,6 @@ def __init_subclass__(cls, **kwargs): # Generate the RediSearch schema once to validate fields. cls.redisearch_schema() - def __init__(self, *args, **kwargs): - if not has_redis_json(self.db()): - log.error( - "Your Redis instance does not have the RedisJson module " - "loaded. JsonModel depends on RedisJson." - ) - super().__init__(*args, **kwargs) - async def save( self: "Model", pipeline: Optional[redis.client.Pipeline] = None ) -> "Model": diff --git a/tests/conftest.py b/tests/conftest.py index 9f067a38..ef19afdd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,8 @@ import pytest -from aredis_om import get_redis_connection +from aredis_om import RedisModel, get_redis_connection +from aredis_om.model.model import DefaultMeta, model_registry TEST_PREFIX = "redis-om:testing" @@ -59,3 +60,13 @@ def cleanup_keys(request): # Delete keys only once if conn.decr(once_key) == 0: _delete_test_keys(TEST_PREFIX, conn) + + +@pytest.fixture(autouse=True) +def reset_meta(): + yield + RedisModel.Meta.database = DefaultMeta + if hasattr(RedisModel, "_meta"): + del RedisModel._meta + RedisModel._conn = None + model_registry.clear() diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 8fec6c0a..838604ff 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -10,6 +10,8 @@ import pytest import pytest_asyncio +from redis import ConnectionError, Redis +from redis.asyncio import Redis as AsyncRedis from aredis_om import ( EmbeddedJsonModel, @@ -849,3 +851,93 @@ async def test_count(members, m): m.Member.first_name == "Kim", m.Member.last_name == "Brookins" ).count() assert actual_count == 1 + + +@py_test_mark_asyncio +async def test_default_connection_not_configured_at_class_definition_time(): + class MyJsonModel(JsonModel): + a_field: int + + assert MyJsonModel._meta.database is None + + +@py_test_mark_asyncio +async def test_default_connection_configured_and_opened_at_usage_time(): + class MyJsonModel(JsonModel): + a_field: int + + obj = MyJsonModel(a_field=42) + await obj.save() + + assert MyJsonModel._meta.database is None + assert isinstance(MyJsonModel._conn, (Redis, AsyncRedis)) + assert MyJsonModel._conn.connection_pool.connection_kwargs["host"] == "localhost" + + +@py_test_mark_asyncio +async def test_custom_connection_configured_at_class_definition_time(): + class MyJsonModel(JsonModel): + a_field: int + + class Meta: + database = Redis(host="10.20.30.40", port=1234) + + assert isinstance(MyJsonModel._meta.database, Redis) + assert ( + MyJsonModel._meta.database.connection_pool.connection_kwargs["host"] + == "10.20.30.40" + ) + assert MyJsonModel._meta.database.connection_pool.connection_kwargs["port"] == 1234 + + +@py_test_mark_asyncio +async def test_custom_connection_opened_at_usage_time(): + class MyJsonModel(JsonModel): + a_field: int + + class Meta: + database = Redis(host="10.20.30.40", port=5678) + + obj = MyJsonModel(a_field=42) + with pytest.raises(ConnectionError, match="connecting to 10.20.30.40:5678"): + await obj.save() + + +@py_test_mark_asyncio +async def test_lazy_connection_configured_and_opened_at_usage_time(): + def my_connection(): + return Redis(host="10.20.30.40", port=9012) + + class MyJsonModel(JsonModel): + a_field: int + + class Meta: + database = my_connection + + obj = MyJsonModel(a_field=42) + + assert not isinstance(MyJsonModel._meta.database, Redis) + assert callable(MyJsonModel._meta.database) + assert MyJsonModel._conn is None + + with pytest.raises(ConnectionError, match="connecting to 10.20.30.40:9012"): + await obj.save() + + +@py_test_mark_asyncio +async def test_lazy_connection_cached(redis): + def my_connection(): + return redis + + class MyJsonModel(JsonModel): + a_field: int + + class Meta: + database = my_connection + + obj = MyJsonModel(a_field=42) + await obj.save() + + assert isinstance(MyJsonModel._conn, (Redis, AsyncRedis)) + assert MyJsonModel.db() is MyJsonModel._conn + assert MyJsonModel.db() is MyJsonModel.db()