From 629d3db7d346ede9785f7ac9f9c1a3deb6a58939 Mon Sep 17 00:00:00 2001 From: Gabriel Erzse Date: Fri, 19 Jul 2024 12:40:38 +0300 Subject: [PATCH 01/78] Restructure client side caching code Right now the client side caching code is implemented mostly on the level of Connections, which is too low. We need to have a shared cache across several connections. Move the cache implementation higher, while trying to encapsulate it better, into a `CacheMixin` class. This is work in progress, many details still need to be taken care of! --- redis/_cache.py | 385 ----------------- redis/_parsers/resp3.py | 48 +-- redis/asyncio/client.py | 72 +--- redis/asyncio/cluster.py | 83 +--- redis/asyncio/connection.py | 119 +---- redis/asyncio/sentinel.py | 1 - redis/cache.py | 143 ++++++ redis/client.py | 82 ++-- redis/cluster.py | 96 ++--- redis/connection.py | 130 +----- redis/sentinel.py | 1 - requirements.txt | 1 + tests/conftest.py | 1 - tests/test_asyncio/conftest.py | 1 - tests/test_asyncio/test_cache.py | 408 ------------------ tests/test_asyncio/test_cluster.py | 2 - tests/test_asyncio/test_connection.py | 1 - tests/test_cache.py | 598 +------------------------- tests/test_cluster.py | 3 - 19 files changed, 287 insertions(+), 1888 deletions(-) delete mode 100644 redis/_cache.py create mode 100644 redis/cache.py delete mode 100644 tests/test_asyncio/test_cache.py diff --git a/redis/_cache.py b/redis/_cache.py deleted file mode 100644 index 90288383d6..0000000000 --- a/redis/_cache.py +++ /dev/null @@ -1,385 +0,0 @@ -import copy -import random -import time -from abc import ABC, abstractmethod -from collections import OrderedDict, defaultdict -from enum import Enum -from typing import List, Sequence, Union - -from redis.typing import KeyT, ResponseT - - -class EvictionPolicy(Enum): - LRU = "lru" - LFU = "lfu" - RANDOM = "random" - - -DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU - -DEFAULT_DENY_LIST = [ - "BF.CARD", - "BF.DEBUG", - "BF.EXISTS", - "BF.INFO", - "BF.MEXISTS", - "BF.SCANDUMP", - "CF.COMPACT", - "CF.COUNT", - "CF.DEBUG", - "CF.EXISTS", - "CF.INFO", - "CF.MEXISTS", - "CF.SCANDUMP", - "CMS.INFO", - "CMS.QUERY", - "DUMP", - "EXPIRETIME", - "FT.AGGREGATE", - "FT.ALIASADD", - "FT.ALIASDEL", - "FT.ALIASUPDATE", - "FT.CURSOR", - "FT.EXPLAIN", - "FT.EXPLAINCLI", - "FT.GET", - "FT.INFO", - "FT.MGET", - "FT.PROFILE", - "FT.SEARCH", - "FT.SPELLCHECK", - "FT.SUGGET", - "FT.SUGLEN", - "FT.SYNDUMP", - "FT.TAGVALS", - "FT._ALIASADDIFNX", - "FT._ALIASDELIFX", - "HRANDFIELD", - "JSON.DEBUG", - "PEXPIRETIME", - "PFCOUNT", - "PTTL", - "SRANDMEMBER", - "TDIGEST.BYRANK", - "TDIGEST.BYREVRANK", - "TDIGEST.CDF", - "TDIGEST.INFO", - "TDIGEST.MAX", - "TDIGEST.MIN", - "TDIGEST.QUANTILE", - "TDIGEST.RANK", - "TDIGEST.REVRANK", - "TDIGEST.TRIMMED_MEAN", - "TOPK.INFO", - "TOPK.LIST", - "TOPK.QUERY", - "TOUCH", - "TTL", -] - -DEFAULT_ALLOW_LIST = [ - "BITCOUNT", - "BITFIELD_RO", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUSBYMEMBER_RO", - "GEORADIUS_RO", - "GEOSEARCH", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "JSON.ARRINDEX", - "JSON.ARRLEN", - "JSON.GET", - "JSON.MGET", - "JSON.OBJKEYS", - "JSON.OBJLEN", - "JSON.RESP", - "JSON.STRLEN", - "JSON.TYPE", - "LCS", - "LINDEX", - "LLEN", - "LPOS", - "LRANGE", - "MGET", - "SCARD", - "SDIFF", - "SINTER", - "SINTERCARD", - "SISMEMBER", - "SMEMBERS", - "SMISMEMBER", - "SORT_RO", - "STRLEN", - "SUBSTR", - "SUNION", - "TS.GET", - "TS.INFO", - "TS.RANGE", - "TS.REVRANGE", - "TYPE", - "XLEN", - "XPENDING", - "XRANGE", - "XREAD", - "XREVRANGE", - "ZCARD", - "ZCOUNT", - "ZDIFF", - "ZINTER", - "ZINTERCARD", - "ZLEXCOUNT", - "ZMSCORE", - "ZRANGE", - "ZRANGEBYLEX", - "ZRANGEBYSCORE", - "ZRANK", - "ZREVRANGE", - "ZREVRANGEBYLEX", - "ZREVRANGEBYSCORE", - "ZREVRANK", - "ZSCORE", - "ZUNION", -] - -_RESPONSE = "response" -_KEYS = "keys" -_CTIME = "ctime" -_ACCESS_COUNT = "access_count" - - -class AbstractCache(ABC): - """ - An abstract base class for client caching implementations. - If you want to implement your own cache you must support these methods. - """ - - @abstractmethod - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - pass - - @abstractmethod - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - pass - - @abstractmethod - def delete_command(self, command: Union[str, Sequence[str]]): - pass - - @abstractmethod - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - pass - - @abstractmethod - def flush(self): - pass - - @abstractmethod - def invalidate_key(self, key: KeyT): - pass - - -class _LocalCache(AbstractCache): - """ - A caching mechanism for storing redis commands and their responses. - - Args: - max_size (int): The maximum number of commands to be stored in the cache. - ttl (int): The time-to-live for each command in seconds. - eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full. - - Attributes: - max_size (int): The maximum number of commands to be stored in the cache. - ttl (int): The time-to-live for each command in seconds. - eviction_policy (EvictionPolicy): The eviction policy used for cache management. - cache (OrderedDict): The ordered dictionary to store commands and their metadata. - key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key. - commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa - """ - - def __init__( - self, - max_size: int = 10000, - ttl: int = 0, - eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - ): - self.max_size = max_size - self.ttl = ttl - self.eviction_policy = eviction_policy - self.cache = OrderedDict() - self.key_commands_map = defaultdict(set) - self.commands_ttl_list = [] - - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - """ - Set a redis command and its response in the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command. - response (ResponseT): The response associated with the command. - keys_in_command (List[KeyT]): The list of keys used in the command. - """ - if len(self.cache) >= self.max_size: - self._evict() - self.cache[command] = { - _RESPONSE: response, - _KEYS: keys_in_command, - _CTIME: time.monotonic(), - _ACCESS_COUNT: 0, # Used only for LFU - } - self._update_key_commands_map(keys_in_command, command) - self.commands_ttl_list.append(command) - - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - """ - Get the response for a redis command from the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command. - - Returns: - ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa - """ - if command in self.cache: - if self._is_expired(command): - self.delete_command(command) - return - self._update_access(command) - return copy.deepcopy(self.cache[command]["response"]) - - def delete_command(self, command: Union[str, Sequence[str]]): - """ - Delete a redis command and its metadata from the cache. - - Args: - command (Union[str, Sequence[str]]): The redis command to be deleted. - """ - if command in self.cache: - keys_in_command = self.cache[command].get("keys") - self._del_key_commands_map(keys_in_command, command) - self.commands_ttl_list.remove(command) - del self.cache[command] - - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - """ - Delete multiple commands and their metadata from the cache. - - Args: - commands (List[Union[str, Sequence[str]]]): The list of commands to be - deleted. - """ - for command in commands: - self.delete_command(command) - - def flush(self): - """Clear the entire cache, removing all redis commands and metadata.""" - self.cache.clear() - self.key_commands_map.clear() - self.commands_ttl_list = [] - - def _is_expired(self, command: Union[str, Sequence[str]]) -> bool: - """ - Check if a redis command has expired based on its time-to-live. - - Args: - command (Union[str, Sequence[str]]): The redis command. - - Returns: - bool: True if the command has expired, False otherwise. - """ - if self.ttl == 0: - return False - return time.monotonic() - self.cache[command]["ctime"] > self.ttl - - def _update_access(self, command: Union[str, Sequence[str]]): - """ - Update the access information for a redis command based on the eviction policy. - - Args: - command (Union[str, Sequence[str]]): The redis command. - """ - if self.eviction_policy == EvictionPolicy.LRU: - self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.LFU: - self.cache[command]["access_count"] = ( - self.cache.get(command, {}).get("access_count", 0) + 1 - ) - self.cache.move_to_end(command) - elif self.eviction_policy == EvictionPolicy.RANDOM: - pass # Random eviction doesn't require updates - - def _evict(self): - """Evict a redis command from the cache based on the eviction policy.""" - if self._is_expired(self.commands_ttl_list[0]): - self.delete_command(self.commands_ttl_list[0]) - elif self.eviction_policy == EvictionPolicy.LRU: - self.cache.popitem(last=False) - elif self.eviction_policy == EvictionPolicy.LFU: - min_access_command = min( - self.cache, key=lambda k: self.cache[k].get("access_count", 0) - ) - self.cache.pop(min_access_command) - elif self.eviction_policy == EvictionPolicy.RANDOM: - random_command = random.choice(list(self.cache.keys())) - self.cache.pop(random_command) - - def _update_key_commands_map( - self, keys: List[KeyT], command: Union[str, Sequence[str]] - ): - """ - Update the key_commands_map with command that uses the keys. - - Args: - keys (List[KeyT]): The list of keys used in the command. - command (Union[str, Sequence[str]]): The redis command. - """ - for key in keys: - self.key_commands_map[key].add(command) - - def _del_key_commands_map( - self, keys: List[KeyT], command: Union[str, Sequence[str]] - ): - """ - Remove a redis command from the key_commands_map. - - Args: - keys (List[KeyT]): The list of keys used in the redis command. - command (Union[str, Sequence[str]]): The redis command. - """ - for key in keys: - self.key_commands_map[key].remove(command) - - def invalidate_key(self, key: KeyT): - """ - Invalidate (delete) all redis commands associated with a specific key. - - Args: - key (KeyT): The key to be invalidated. - """ - if key not in self.key_commands_map: - return - commands = list(self.key_commands_map[key]) - for command in commands: - self.delete_command(command) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index cc210b9df5..0e0a6655d2 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -120,6 +120,12 @@ def _read_response(self, disable_decoding=False, push_request=False): response = self.handle_push_response( response, disable_decoding, push_request ) + if not push_request: + return self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -128,19 +134,10 @@ def _read_response(self, disable_decoding=False, push_request=False): return response def handle_push_response(self, response, disable_decoding, push_request): - if response[0] in _INVALIDATION_MESSAGE: - if self.invalidation_push_handler_func: - res = self.invalidation_push_handler_func(response) - else: - res = None - else: - res = self.pubsub_push_handler_func(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + if response[0] not in _INVALIDATION_MESSAGE: + return self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return self.invalidation_push_handler_func(response) def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -155,7 +152,7 @@ def __init__(self, socket_read_size): self.pubsub_push_handler_func = self.handle_pubsub_push_response self.invalidation_push_handler_func = None - def handle_pubsub_push_response(self, response): + async def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response @@ -267,6 +264,12 @@ async def _read_response( response = await self.handle_push_response( response, disable_decoding, push_request ) + if not push_request: + return await self._read_response( + disable_decoding=disable_decoding, push_request=push_request + ) + else: + return response else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -275,19 +278,10 @@ async def _read_response( return response async def handle_push_response(self, response, disable_decoding, push_request): - if response[0] in _INVALIDATION_MESSAGE: - if self.invalidation_push_handler_func: - res = self.invalidation_push_handler_func(response) - else: - res = None - else: - res = self.pubsub_push_handler_func(response) - if not push_request: - return await self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return res + if response[0] not in _INVALIDATION_MESSAGE: + return await self.pubsub_push_handler_func(response) + if self.invalidation_push_handler_func: + return await self.invalidation_push_handler_func(response) def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 1845b7252f..5d93c83b12 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -26,12 +26,6 @@ cast, ) -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -239,13 +233,6 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 100, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Redis client. @@ -295,13 +282,6 @@ def __init__( "lib_version": lib_version, "redis_connect_func": redis_connect_func, "protocol": protocol, - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -624,31 +604,22 @@ async def _disconnect_raise(self, conn: Connection, error: Exception): async def execute_command(self, *args, **options): """Execute a command and return a parsed response""" await self.initialize() - command_name = args[0] - keys = options.pop("keys", None) # keys are used only for client side caching pool = self.connection_pool + command_name = args[0] conn = self.connection or await pool.get_connection(command_name, **options) - response_from_cache = await conn._get_from_local_cache(args) + + if self.single_connection_client: + await self._single_conn_lock.acquire() try: - if response_from_cache is not None: - return response_from_cache - else: - try: - if self.single_connection_client: - await self._single_conn_lock.acquire() - response = await conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - if keys: - conn._add_to_local_cache(args, response, keys) - return response - finally: - if self.single_connection_client: - self._single_conn_lock.release() + return await conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: + if self.single_connection_client: + self._single_conn_lock.release() if not self.connection: await pool.release(conn) @@ -677,24 +648,6 @@ async def parse_response( return await retval if inspect.isawaitable(retval) else retval return response - def flush_cache(self): - if self.connection: - self.connection.flush_cache() - else: - self.connection_pool.flush_cache() - - def delete_command_from_cache(self, command): - if self.connection: - self.connection.delete_command_from_cache(command) - else: - self.connection_pool.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.connection: - self.connection.invalidate_key_from_cache(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - StrictRedis = Redis @@ -1331,7 +1284,6 @@ def multi(self): def execute_command( self, *args, **kwargs ) -> Union["Pipeline", Awaitable["Pipeline"]]: - kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 40b2948a7f..cbceccf401 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -19,12 +19,6 @@ Union, ) -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) from redis._parsers import AsyncCommandsParser, Encoder from redis._parsers.helpers import ( _RedisCallbacks, @@ -276,13 +270,6 @@ def __init__( ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 100, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ) -> None: if db: raise RedisClusterException( @@ -326,14 +313,6 @@ def __init__( "socket_timeout": socket_timeout, "retry": retry, "protocol": protocol, - # Client cache related kwargs - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } if ssl: @@ -938,18 +917,6 @@ def lock( thread_local=thread_local, ) - def flush_cache(self): - if self.nodes_manager: - self.nodes_manager.flush_cache() - - def delete_command_from_cache(self, command): - if self.nodes_manager: - self.nodes_manager.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.nodes_manager: - self.nodes_manager.invalidate_key_from_cache(key) - class ClusterNode: """ @@ -1076,25 +1043,16 @@ async def parse_response( async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Acquire connection connection = self.acquire_connection() - keys = kwargs.pop("keys", None) - response_from_cache = await connection._get_from_local_cache(args) - if response_from_cache is not None: - self._free.append(connection) - return response_from_cache - else: - # Execute command - await connection.send_packed_command(connection.pack_command(*args), False) + # Execute command + await connection.send_packed_command(connection.pack_command(*args), False) - # Read response - try: - response = await self.parse_response(connection, args[0], **kwargs) - if keys: - connection._add_to_local_cache(args, response, keys) - return response - finally: - # Release connection - self._free.append(connection) + # Read response + try: + return await self.parse_response(connection, args[0], **kwargs) + finally: + # Release connection + self._free.append(connection) async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: # Acquire connection @@ -1121,18 +1079,6 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool: return ret - def flush_cache(self): - for connection in self._connections: - connection.flush_cache() - - def delete_command_from_cache(self, command): - for connection in self._connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for connection in self._connections: - connection.invalidate_key_from_cache(key) - class NodesManager: __slots__ = ( @@ -1416,18 +1362,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def flush_cache(self): - for node in self.nodes_cache.values(): - node.flush_cache() - - def delete_command_from_cache(self, command): - for node in self.nodes_cache.values(): - node.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for node in self.nodes_cache.values(): - node.invalidate_key_from_cache(key) - class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ @@ -1516,7 +1450,6 @@ def execute_command( or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ - kwargs.pop("keys", None) # the keys are used only for client side caching self._command_stack.append( PipelineCommand(len(self._command_stack), *args, **kwargs) ) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2ac6637986..ddbd22c95d 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -49,16 +49,9 @@ ResponseError, TimeoutError, ) -from redis.typing import EncodableT, KeysT, ResponseT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes -from .._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, - _LocalCache, -) from .._parsers import ( BaseParser, Encoder, @@ -121,9 +114,6 @@ class AbstractConnection: "encoder", "ssl_context", "protocol", - "client_cache", - "cache_deny_list", - "cache_allow_list", "_reader", "_writer", "_parser", @@ -158,13 +148,6 @@ def __init__( encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): if (username or password) and credential_provider is not None: raise DataError( @@ -222,18 +205,6 @@ def __init__( if p < 2 or p > 3: raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol - if cache_enabled: - _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy) - else: - _cache = None - self.client_cache = client_cache if client_cache is not None else _cache - if self.client_cache is not None: - if self.protocol not in [3, "3"]: - raise RedisError( - "client caching is only supported with protocol version 3 or higher" - ) - self.cache_deny_list = cache_deny_list - self.cache_allow_list = cache_allow_list def __del__(self, _warnings: Any = warnings): # For some reason, the individual streams don't get properly garbage @@ -425,11 +396,6 @@ async def on_connect(self) -> None: # if a database is specified, switch to it. Also pipeline this if self.db: await self.send_command("SELECT", self.db) - # if client caching is enabled, start tracking - if self.client_cache: - await self.send_command("CLIENT", "TRACKING", "ON") - await self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -464,9 +430,6 @@ async def disconnect(self, nowait: bool = False) -> None: raise TimeoutError( f"Timed out closing connection after {self.socket_connect_timeout}" ) from None - finally: - if self.client_cache: - self.client_cache.flush() async def _send_ping(self): """Send PING, expect PONG in return""" @@ -688,60 +651,9 @@ def _socket_is_empty(self): """Check if the socket is empty""" return len(self._reader._buffer) == 0 - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. - (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is None: - self.client_cache.flush() - else: - for key in data[1]: - self.client_cache.invalidate_key(str_if_bytes(key)) - - async def _get_from_local_cache(self, command: str): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list - ): - return None + async def process_invalidation_messages(self): while not self._socket_is_empty(): await self.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Tuple[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) - ): - self.client_cache.set(command, response, keys) - - def flush_cache(self): - if self.client_cache: - self.client_cache.flush() - - def delete_command_from_cache(self, command): - if self.client_cache: - self.client_cache.delete_command(command) - - def invalidate_key_from_cache(self, key): - if self.client_cache: - self.client_cache.invalidate_key(key) class Connection(AbstractConnection): @@ -1177,18 +1089,12 @@ def make_connection(self): async def ensure_connection(self, connection: AbstractConnection): """Ensure that the connection object is connected and valid""" await connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. - # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) try: - if ( - await connection.can_read_destructive() - and connection.client_cache is None - ): + if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None except (ConnectionError, OSError): await connection.disconnect() @@ -1235,21 +1141,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def flush_cache(self): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.flush_cache() - - def delete_command_from_cache(self, command: str): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key: str): - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.invalidate_key_from_cache(key) - class BlockingConnectionPool(ConnectionPool): """ diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index 6fd233adc8..5d4608ed2f 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -225,7 +225,6 @@ async def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/redis/cache.py b/redis/cache.py new file mode 100644 index 0000000000..2a50565554 --- /dev/null +++ b/redis/cache.py @@ -0,0 +1,143 @@ +from typing import Callable, TypeVar, Any, NoReturn, List, Union +from typing import Optional + +from cachetools import TTLCache, Cache, LRUCache +from cachetools.keys import hashkey + +from redis.typing import ResponseT + +T = TypeVar('T') + + +def ensure_string(key): + if isinstance(key, bytes): + return key.decode('utf-8') + elif isinstance(key, str): + return key + else: + raise TypeError("Key must be either a string or bytes") + + +class CacheMixin: + def __init__(self, + use_cache: bool, + connection_pool: "ConnectionPool", + cache: Optional[Cache] = None, + cache_size: int = 128, + cache_ttl: int = 300, + ) -> None: + self.use_cache = use_cache + if not use_cache: + return + if cache is not None: + self.cache = cache + else: + self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl) + self.keys_mapping = LRUCache(maxsize=10000) + self.wrap_connection_pool(connection_pool) + self.connections = [] + + def cached_call(self, + func: Callable[..., ResponseT], + *args, + **options) -> ResponseT: + if not self.use_cache: + return func(*args, **options) + + print(f'Cached call with args {args} and options {options}') + + keys = None + if 'keys' in options: + keys = options['keys'] + if not isinstance(keys, list): + raise TypeError("Cache keys must be a list.") + if not keys: + return func(*args, **options) + print(f'keys {keys}') + + cache_key = hashkey(*args) + + for conn in self.connections: + conn.process_invalidation_messages() + + for key in keys: + if key in self.keys_mapping: + if cache_key not in self.keys_mapping[key]: + self.keys_mapping[key].append(cache_key) + else: + self.keys_mapping[key] = [cache_key] + + if cache_key in self.cache: + result = self.cache[cache_key] + print(f'Cached call for {args} yields cached result {result}') + return result + + result = func(*args, **options) + self.cache[cache_key] = result + print(f'Cached call for {args} yields computed result {result}') + return result + + def get_cache_entry(self, *args: Any) -> Any: + cache_key = hashkey(*args) + return self.cache.get(cache_key, None) + + def invalidate_cache_entry(self, *args: Any) -> None: + cache_key = hashkey(*args) + if cache_key in self.cache: + self.cache.pop(cache_key) + + def wrap_connection_pool(self, connection_pool: "ConnectionPool"): + if not self.use_cache: + return + if connection_pool is None: + return + original_maker = connection_pool.make_connection + connection_pool.make_connection = lambda: self._make_connection(original_maker) + + def _make_connection(self, original_maker: Callable[[], "Connection"]): + conn = original_maker() + original_disconnect = conn.disconnect + conn.disconnect = lambda: self._wrapped_disconnect(conn, original_disconnect) + self.add_connection(conn) + return conn + + def _wrapped_disconnect(self, connection: "Connection", + original_disconnect: Callable[[], NoReturn]): + original_disconnect() + self.remove_connection(connection) + + def add_connection(self, conn): + print(f'Tracking connection {conn} {id(conn)}') + conn.register_connect_callback(self._on_connect) + self.connections.append(conn) + + def _on_connect(self, conn): + conn.send_command("CLIENT", "TRACKING", "ON") + response = conn.read_response() + print(f"Client tracking response {response}") + conn._parser.set_invalidation_push_handler(self._cache_invalidation_process) + + def _cache_invalidation_process( + self, data: List[Union[str, Optional[List[str]]]] + ) -> None: + """ + Invalidate (delete) all redis commands associated with a specific key. + `data` is a list of strings, where the first string is the invalidation message + and the second string is the list of keys to invalidate. + (if the list of keys is None, then all keys are invalidated) + """ + print(f'Invalidation {data}') + if data[1] is None: + self.cache.clear() + else: + for key in data[1]: + normalized_key = ensure_string(key) + print(f'Invalidating normalized key {normalized_key}') + if normalized_key in self.keys_mapping: + for cache_key in self.keys_mapping[normalized_key]: + print(f'Invalidating cache key {cache_key}') + self.cache.pop(cache_key) + + def remove_connection(self, conn): + print(f'Untracking connection {conn} {id(conn)}') + self.connections.remove(conn) diff --git a/redis/client.py b/redis/client.py index b7a1f88d92..6d3a11a2ba 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,12 +6,8 @@ from itertools import chain from typing import Any, Callable, Dict, List, Optional, Type, Union -from redis._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, -) +from cachetools import Cache + from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, @@ -19,6 +15,7 @@ _RedisCallbacksRESP3, bool_ok, ) +from redis.cache import CacheMixin from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -89,7 +86,7 @@ class AbstractRedis: pass -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands, CacheMixin): """ Implementation of the Redis protocol. @@ -147,10 +144,12 @@ class initializer. In the case of conflicting arguments, querystring """ single_connection_client = kwargs.pop("single_connection_client", False) + use_cache = kwargs.pop("use_cache", False) connection_pool = ConnectionPool.from_url(url, **kwargs) client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, + use_cache=use_cache ) client.auto_close_connection_pool = True return client @@ -216,13 +215,10 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, + use_cache: bool = False, + cache: Optional[Cache] = None, + cache_size: int = 128, + cache_ttl: int = 300, ) -> None: """ Initialize a new Redis client. @@ -274,13 +270,6 @@ def __init__( "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, "protocol": protocol, - "cache_enabled": cache_enabled, - "client_cache": client_cache, - "cache_max_size": cache_max_size, - "cache_ttl": cache_ttl, - "cache_policy": cache_policy, - "cache_deny_list": cache_deny_list, - "cache_allow_list": cache_allow_list, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -328,6 +317,11 @@ def __init__( self.auto_close_connection_pool = False self.connection_pool = connection_pool + + if use_cache and self.connection_pool.get_protocol() not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + CacheMixin.__init__(self, use_cache, self.connection_pool, cache, cache_size, cache_ttl) + self.connection = None if single_connection_client: self.connection = self.connection_pool.get_connection("_") @@ -559,25 +553,22 @@ def _disconnect_raise(self, conn, error): # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): + if self.use_cache: + return self.cached_call(self._execute_command, *args, **options) + return self._execute_command(*args, **options) + + def _execute_command(self, *args, **options): """Execute a command and return a parsed response""" - command_name = args[0] - keys = options.pop("keys", None) pool = self.connection_pool + command_name = args[0] conn = self.connection or pool.get_connection(command_name, **options) - response_from_cache = conn._get_from_local_cache(args) try: - if response_from_cache is not None: - return response_from_cache - else: - response = conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_raise(conn, error), - ) - if keys: - conn._add_to_local_cache(args, response, keys) - return response + return conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_raise(conn, error), + ) finally: if not self.connection: pool.release(conn) @@ -602,24 +593,6 @@ def parse_response(self, connection, command_name, **options): return self.response_callbacks[command_name](response, **options) return response - def flush_cache(self): - if self.connection: - self.connection.flush_cache() - else: - self.connection_pool.flush_cache() - - def delete_command_from_cache(self, command): - if self.connection: - self.connection.delete_command_from_cache(command) - else: - self.connection_pool.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.connection: - self.connection.invalidate_key_from_cache(key) - else: - self.connection_pool.invalidate_key_from_cache(key) - StrictRedis = Redis @@ -1314,7 +1287,6 @@ def multi(self) -> None: self.explicit_transaction = True def execute_command(self, *args, **kwargs): - kwargs.pop("keys", None) # the keys are used only for client side caching if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) diff --git a/redis/cluster.py b/redis/cluster.py index be7685e9a1..39e8c4b9ea 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,9 +6,12 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from cachetools import Cache + from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff +from redis.cache import CacheMixin from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -167,13 +170,7 @@ def parse_cluster_myshardid(resp, **options): "ssl_password", "unix_socket_path", "username", - "cache_enabled", - "client_cache", - "cache_max_size", - "cache_ttl", - "cache_policy", - "cache_deny_list", - "cache_allow_list", + "use_cache", ) KWARGS_DISABLED_KEYS = ("host", "port") @@ -449,7 +446,7 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None: self.nodes_manager.default_node = random.choice(replicas) -class RedisCluster(AbstractRedisCluster, RedisClusterCommands): +class RedisCluster(AbstractRedisCluster, RedisClusterCommands, CacheMixin): @classmethod def from_url(cls, url, **kwargs): """ @@ -507,6 +504,7 @@ def __init__( dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + use_cache: Optional[bool] = False, **kwargs, ): """ @@ -642,6 +640,7 @@ def __init__( require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, + use_cache=use_cache, **kwargs, ) @@ -649,6 +648,12 @@ def __init__( self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) + + protocol = kwargs.get("protocol", None) + if use_cache and protocol not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + CacheMixin.__init__(self, use_cache, None) + self.commands_parser = CommandsParser(self) self._lock = threading.Lock() @@ -1051,7 +1056,12 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - def execute_command(self, *args, **kwargs): + def execute_command(self, *args, **options): + if self.use_cache: + return self.cached_call(self._execute_command, *args, **options) + return self._internal_execute_command(*args, **options) + + def _internal_execute_command(self, *args, **kwargs): """ Wrapper for ERRORS_ALLOW_RETRY error handling. @@ -1125,7 +1135,6 @@ def _execute_command(self, target_node, *args, **kwargs): """ Send a command to a node in the cluster """ - keys = kwargs.pop("keys", None) command = args[0] redis_node = None connection = None @@ -1154,19 +1163,13 @@ def _execute_command(self, target_node, *args, **kwargs): connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - response_from_cache = connection._get_from_local_cache(args) - if response_from_cache is not None: - return response_from_cache - else: - connection.send_command(*args) - response = redis_node.parse_response(connection, command, **kwargs) - if command in self.cluster_response_callbacks: - response = self.cluster_response_callbacks[command]( - response, **kwargs - ) - if keys: - connection._add_to_local_cache(args, response, keys) - return response + connection.send_command(*args) + response = redis_node.parse_response(connection, command, **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs + ) + return response except AuthenticationError: raise except (ConnectionError, TimeoutError) as e: @@ -1266,18 +1269,6 @@ def load_external_module(self, funcname, func): """ setattr(self, funcname, func) - def flush_cache(self): - if self.nodes_manager: - self.nodes_manager.flush_cache() - - def delete_command_from_cache(self, command): - if self.nodes_manager: - self.nodes_manager.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.nodes_manager: - self.nodes_manager.invalidate_key_from_cache(key) - class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1306,18 +1297,6 @@ def __del__(self): if self.redis_connection is not None: self.redis_connection.close() - def flush_cache(self): - if self.redis_connection is not None: - self.redis_connection.flush_cache() - - def delete_command_from_cache(self, command): - if self.redis_connection is not None: - self.redis_connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - if self.redis_connection is not None: - self.redis_connection.invalidate_key_from_cache(key) - class LoadBalancer: """ @@ -1338,7 +1317,7 @@ def reset(self) -> None: self.primary_to_idx.clear() -class NodesManager: +class NodesManager(CacheMixin): def __init__( self, startup_nodes, @@ -1348,6 +1327,8 @@ def __init__( dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, + use_cache: Optional[bool] = False, + cache: Optional[Cache] = None, **kwargs, ): self.nodes_cache = {} @@ -1360,12 +1341,14 @@ def __init__( self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap + self.use_cache = use_cache self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() if lock is None: lock = threading.Lock() self._lock = lock + CacheMixin.__init__(self, use_cache, None, cache) self.initialize() def get_node(self, host=None, port=None, node_name=None): @@ -1503,9 +1486,9 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - r = Redis(connection_pool=self.connection_pool_class(**kwargs)) + r = Redis(connection_pool=self.connection_pool_class(**kwargs), use_cache=self.use_cache, cache=self.cache) else: - r = Redis(host=host, port=port, **kwargs) + r = Redis(host=host, port=port, use_cache=self.use_cache, cache=self.cache, **kwargs) return r def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): @@ -1681,18 +1664,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def flush_cache(self): - for node in self.nodes_cache.values(): - node.flush_cache() - - def delete_command_from_cache(self, command): - for node in self.nodes_cache.values(): - node.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key): - for node in self.nodes_cache.values(): - node.invalidate_key_from_cache(key) - class ClusterPubSub(PubSub): """ @@ -2008,7 +1979,6 @@ def execute_command(self, *args, **kwargs): """ Wrapper function for pipeline_execute_command """ - kwargs.pop("keys", None) # the keys are used only for client side caching return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): diff --git a/redis/connection.py b/redis/connection.py index 1f862d0371..cf2c7e97d5 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,16 +9,9 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Sequence, Type, Union +from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from ._cache import ( - DEFAULT_ALLOW_LIST, - DEFAULT_DENY_LIST, - DEFAULT_EVICTION_POLICY, - AbstractCache, - _LocalCache, -) from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -33,7 +26,6 @@ TimeoutError, ) from .retry import Retry -from .typing import KeysT, ResponseT from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, @@ -158,13 +150,6 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Connection. @@ -230,18 +215,6 @@ def __init__( # p = DEFAULT_RESP_VERSION self.protocol = p self._command_packer = self._construct_command_packer(command_packer) - if cache_enabled: - _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy) - else: - _cache = None - self.client_cache = client_cache if client_cache is not None else _cache - if self.client_cache is not None: - if self.protocol not in [3, "3"]: - raise RedisError( - "client caching is only supported with protocol version 3 or higher" - ) - self.cache_deny_list = cache_deny_list - self.cache_allow_list = cache_allow_list def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -432,12 +405,6 @@ def on_connect(self): if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") - # if client caching is enabled, start tracking - if self.client_cache: - self.send_command("CLIENT", "TRACKING", "ON") - self.read_response() - self._parser.set_invalidation_push_handler(self._cache_invalidation_process) - def disconnect(self, *args): "Disconnects from the Redis server" self._parser.on_disconnect() @@ -458,9 +425,6 @@ def disconnect(self, *args): except OSError: pass - if self.client_cache: - self.client_cache.flush() - def _send_ping(self): """Send PING, expect PONG in return""" self.send_command("PING", check_health=False) @@ -608,60 +572,10 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. - (if the list of keys is None, then all keys are invalidated) - """ - if data[1] is None: - self.client_cache.flush() - else: - for key in data[1]: - self.client_cache.invalidate_key(str_if_bytes(key)) - - def _get_from_local_cache(self, command: Sequence[str]): - """ - If the command is in the local cache, return the response - """ - if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list - ): - return None + def process_invalidation_messages(self): + print(f'connection {self} {id(self)} process invalidations') while self.can_read(): self.read_response(push_request=True) - return self.client_cache.get(command) - - def _add_to_local_cache( - self, command: Sequence[str], response: ResponseT, keys: List[KeysT] - ): - """ - Add the command and response to the local cache if the command - is allowed to be cached - """ - if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) - ): - self.client_cache.set(command, response, keys) - - def flush_cache(self): - if self.client_cache: - self.client_cache.flush() - - def delete_command_from_cache(self, command: Union[str, Sequence[str]]): - if self.client_cache: - self.client_cache.delete_command(command) - - def invalidate_key_from_cache(self, key: KeysT): - if self.client_cache: - self.client_cache.invalidate_key(key) class Connection(AbstractConnection): @@ -1110,6 +1024,14 @@ def __repr__(self) -> (str, str): f"({repr(self.connection_class(**self.connection_kwargs))})>" ) + def get_protocol(self): + """ + Returns: + The RESP protocol version, or ``None`` if the protocol is not specified, + in which case the server default will be used. + """ + return self.connection_kwargs.get("protocol", None) + def reset(self) -> None: self._lock = threading.Lock() self._created_connections = 0 @@ -1187,15 +1109,12 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": try: # ensure this connection is connected to Redis connection.connect() - # if client caching is not enabled connections that the pool - # provides should be ready to send a command. - # if not, the connection was either returned to the + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. - # (if caching enabled the connection will not always be ready - # to send a command because it may contain invalidation messages) try: - if connection.can_read() and connection.client_cache is None: + if connection.can_read(): raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() @@ -1281,27 +1200,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def flush_cache(self): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.flush_cache() - - def delete_command_from_cache(self, command: str): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.delete_command_from_cache(command) - - def invalidate_key_from_cache(self, key: str): - self._checkpid() - with self._lock: - connections = chain(self._available_connections, self._in_use_connections) - for connection in connections: - connection.invalidate_key_from_cache(key) - class BlockingConnectionPool(ConnectionPool): """ diff --git a/redis/sentinel.py b/redis/sentinel.py index 72b5bef548..e0437c81cd 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -252,7 +252,6 @@ def execute_command(self, *args, **kwargs): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - kwargs.pop("keys", None) # the keys are used only for client side caching once = bool(kwargs.get("once", False)) if "once" in kwargs.keys(): kwargs.pop("once") diff --git a/requirements.txt b/requirements.txt index 3274a80f62..26aed50b9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ async-timeout>=4.0.3 +cachetools diff --git a/tests/conftest.py b/tests/conftest.py index dd78bb6a2c..97d73773ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -441,7 +441,6 @@ def _gen_cluster_mock_resp(r, response): connection = Mock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6e93407b4c..41b47b2268 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -146,7 +146,6 @@ def _gen_cluster_mock_resp(r, response): connection = mock.AsyncMock(spec=Connection) connection.retry = Retry(NoBackoff(), 0) connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None with mock.patch.object(r, "connection", connection): yield r diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py deleted file mode 100644 index 7a7f881ce2..0000000000 --- a/tests/test_asyncio/test_cache.py +++ /dev/null @@ -1,408 +0,0 @@ -import time - -import pytest -import pytest_asyncio -from redis._cache import EvictionPolicy, _LocalCache -from redis.utils import HIREDIS_AVAILABLE - - -@pytest_asyncio.fixture -async def r(request, create_redis): - cache = request.param.get("cache") - kwargs = request.param.get("kwargs", {}) - r = await create_redis(protocol=3, client_cache=cache, **kwargs) - yield r, cache - - -@pytest_asyncio.fixture() -async def local_cache(): - yield _LocalCache() - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -class TestLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - @pytest.mark.onlynoncluster - async def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True) - async def test_cache_lru_eviction(self, r): - r, cache = r - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) - async def test_cache_ttl(self, r): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], - indirect=True, - ) - async def test_cache_lfu_eviction(self, r): - r, cache = r - # add 3 keys to redis - await r.set("foo", "bar") - await r.set("foo2", "bar2") - await r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert await r.get("foo") == b"bar" - assert await r.get("foo2") == b"bar2" - assert await r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - await r.set("foo4", "bar4") - assert await r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - async def test_cache_decode_response(self, r): - r, cache = r - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], - indirect=True, - ) - async def test_cache_deny_list(self, r): - r, cache = r - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], - indirect=True, - ) - async def test_cache_allow_list(self, r): - r, cache = r - # add list to redis - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.llen("mylist") == 3 - assert await r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) == 3 - assert cache.get(("LINDEX", "mylist", 1)) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - async def test_cache_return_copy(self, r): - r, cache = r - await r.lpush("mylist", "foo", "bar", "baz") - assert await r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"] - res = cache.get(("LRANGE", "mylist", 0, -1)) - assert res == [b"baz", b"bar", b"foo"] - res.append(b"new") - check = cache.get(("LRANGE", "mylist", 0, -1)) - assert check == [b"baz", b"bar", b"foo"] - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - async def test_csc_not_cause_disconnects(self, r): - r, cache = r - id1 = await r.client_id() - await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1}) - assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] - id2 = await r.client_id() - - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"] - assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [ - "1", - "1", - "1", - "1", - "1", - ] - - await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2}) - id3 = await r.client_id() - # client should get value from redis server post invalidate messages - assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"] - - await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3}) - # need to check that we get correct value 3 and not 2 - assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"] - - await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4}) - # need to check that we get correct value 4 and not 3 - assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] - # client should get value from client cache - assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"] - id4 = await r.client_id() - assert id1 == id2 == id3 == id4 - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert await r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert ( - await r.execute_command("GET", "b") == "2" - ) # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_delete_one_command(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete one command from the cache - r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) - # the other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_invalidate_key(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # invalidate one key from the cache - r.invalidate_key_from_cache("b{a}") - # one other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_flush_entire_cache(self, r): - r, cache = r - assert await r.mset({"a{a}": 1, "b{a}": 1}) is True - assert await r.set("c", 1) is True - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # flush the local cache - r.flush_cache() - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert await r.mget("a{a}", "b{a}") == ["1", "1"] - assert await r.get("c") == "1" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlycluster -class TestClusterLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - async def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - await r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_cache_decode_response(self, r): - r, cache = r - await r.set("foo", "bar") - # get key from redis and save in local cache - assert await r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - await r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert await r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert await r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert await r.execute_command("SET", "b", "2") is True - assert ( - await r.execute_command("GET", "b") == "2" - ) # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestSentinelLocalCache: - - async def test_get_from_cache(self, local_cache, master): - await master.set("foo", "bar") - # get key from redis and save in local cache - assert await master.get("foo") == b"bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - await master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert await master.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "sentinel_setup", - [{"kwargs": {"decode_responses": True}}], - indirect=True, - ) - async def test_cache_decode_response(self, local_cache, sentinel_setup, master): - await master.set("foo", "bar") - # get key from redis and save in local cache - assert await master.get("foo") == "bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - await master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - await master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert await master.get("foo") == "barbar" diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a36040f11b..57dfd25fb6 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -190,7 +190,6 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) @@ -201,7 +200,6 @@ def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode: connection = mock.AsyncMock(spec=Connection) connection.is_connected = True connection.read_response.side_effect = exc - connection._get_from_local_cache.return_value = None while node._free: node._free.pop() node._free.append(connection) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8f79f7d947..e584fc6999 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -75,7 +75,6 @@ async def call_with_retry(self, _, __): mock_conn = mock.AsyncMock(spec=Connection) mock_conn.retry = Retry_() - mock_conn._get_from_local_cache.return_value = None async def get_conn(_): # Validate only one client is created in single-client mode when diff --git a/tests/test_cache.py b/tests/test_cache.py index 022364e87a..4eda78ebbb 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,587 +1,35 @@ import time -from collections import defaultdict -from typing import List, Sequence, Union -import cachetools -import pytest -import redis -from redis import RedisError -from redis._cache import AbstractCache, EvictionPolicy, _LocalCache -from redis.typing import KeyT, ResponseT -from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import _get_client +from redis import Redis, RedisCluster -@pytest.fixture() -def r(request): - cache = request.param.get("cache") - kwargs = request.param.get("kwargs", {}) - protocol = request.param.get("protocol", 3) - single_connection_client = request.param.get("single_connection_client", False) - with _get_client( - redis.Redis, - request, - single_connection_client=single_connection_client, - protocol=protocol, - client_cache=cache, - **kwargs, - ) as client: - yield client, cache +def test_standalone_cached_get_and_set(): + r = Redis(use_cache=True, protocol=3) + assert r.set("key", 5) + assert r.get("key") == b"5" + r2 = Redis(protocol=3) + r2.set("key", "foo") -@pytest.fixture() -def local_cache(): - return _LocalCache() + time.sleep(0.5) + after_invalidation = r.get("key") + print(f'after invalidation {after_invalidation}') + assert after_invalidation == b"foo" -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -class TestLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - @pytest.mark.onlynoncluster - def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(max_size=3)}], - indirect=True, - ) - def test_cache_lru_eviction(self, r): - r, cache = r - # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None +def test_cluster_cached_get_and_set(): + cluster_url = "redis://localhost:16379/0" - @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True) - def test_cache_ttl(self, r): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None + r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) + assert r.set("key", 5) + assert r.get("key") == b"5" - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}], - indirect=True, - ) - def test_cache_lfu_eviction(self, r): - r, cache = r - # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # test the eviction policy - assert len(cache.cache) == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None + r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) + r2.set("key", "foo") - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_cache_decode_response(self, r): - r, cache = r - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}], - indirect=True, - ) - def test_cache_deny_list(self, r): - r, cache = r - # add list to redis - r.lpush("mylist", "foo", "bar", "baz") - assert r.llen("mylist") == 3 - assert r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) is None - assert cache.get(("LINDEX", "mylist", 1)) == b"bar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}], - indirect=True, - ) - def test_cache_allow_list(self, r): - r, cache = r - r.lpush("mylist", "foo", "bar", "baz") - assert r.llen("mylist") == 3 - assert r.lindex("mylist", 1) == b"bar" - assert cache.get(("LLEN", "mylist")) == 3 - assert cache.get(("LINDEX", "mylist", 1)) is None - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_cache_return_copy(self, r): - r, cache = r - r.lpush("mylist", "foo", "bar", "baz") - assert r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"] - res = cache.get(("LRANGE", "mylist", 0, -1)) - assert res == [b"baz", b"bar", b"foo"] - res.append(b"new") - check = cache.get(("LRANGE", "mylist", 0, -1)) - assert check == [b"baz", b"bar", b"foo"] - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_csc_not_cause_disconnects(self, r): - r, cache = r - id1 = r.client_id() - r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1}) - assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] - id2 = r.client_id() - - # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"] - assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [ - "1", - "1", - "1", - "1", - "1", - "1", - ] - - r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2}) - id3 = r.client_id() - # client should get value from redis server post invalidate messages - assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"] - - r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3}) - # need to check that we get correct value 3 and not 2 - assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] - # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"] - - r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4}) - # need to check that we get correct value 4 and not 3 - assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] - # client should get value from client cache - assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"] - id4 = r.client_id() - assert id1 == id2 == id3 == id4 - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_multiple_commands_same_key(self, r): - r, cache = r - r.mset({"a": 1, "b": 1}) - assert r.mget("a", "b") == ["1", "1"] - # value should be in local cache - assert cache.get(("MGET", "a", "b")) == ["1", "1"] - # set only one key - r.set("a", 2) - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("MGET", "a", "b")) is None - # get from redis - assert r.mget("a", "b") == ["2", "1"] - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_delete_one_command(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete one command from the cache - r.delete_command_from_cache(("MGET", "a{a}", "b{a}")) - # the other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_delete_several_commands(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # delete the commands from the cache - cache.delete_commands([("MGET", "a{a}", "b{a}"), ("GET", "c")]) - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_invalidate_key(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # invalidate one key from the cache - r.invalidate_key_from_cache("b{a}") - # one other command is still in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) == "1" - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_flush_entire_cache(self, r): - r, cache = r - r.mset({"a{a}": 1, "b{a}": 1}) - r.set("c", 1) - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - # values should be in local cache - assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"] - assert cache.get(("GET", "c")) == "1" - # flush the local cache - r.flush_cache() - # the commands are not in the local cache anymore - assert cache.get(("MGET", "a{a}", "b{a}")) is None - assert cache.get(("GET", "c")) is None - # get from redis - assert r.mget("a{a}", "b{a}") == ["1", "1"] - assert r.get("c") == "1" - - @pytest.mark.onlynoncluster - def test_cache_not_available_with_resp2(self, request): - with pytest.raises(RedisError) as e: - _get_client(redis.Redis, request, protocol=2, client_cache=_LocalCache()) - assert "protocol version 3 or higher" in str(e.value) - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_execute_command_args_not_split(self, r): - r, cache = r - assert r.execute_command("SET a 1") == "OK" - assert r.execute_command("GET a") == "1" - # "get a" is not whitelisted by default, the args should be separated - assert cache.get(("GET a",)) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b") == "2" # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "single_connection_client": True}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_single_connection(self, r): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" - - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_get_from_cache_invalidate_via_get(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # don't send any command to redis, just run another get - # it should process the invalidation in background - assert r.get("foo") == b"barbar" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlycluster -class TestClusterLocalCache: - @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True) - def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_cache_decode_response(self, r): - r, cache = r - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == "bar" - # get key from local cache - assert cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # send any command to redis (process invalidation in background) - node = r.get_node_from_key("foo") - r.ping(target_nodes=node) - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == "barbar" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b", keys=["b"]) == "2" - assert cache.get(("GET", "b")) == "2" - - @pytest.mark.parametrize( - "r", - [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_execute_command_keys_not_provided(self, r): - r, cache = r - assert r.execute_command("SET", "b", "2") is True - assert r.execute_command("GET", "b") == "2" # keys not provided, not cached - assert cache.get(("GET", "b")) is None - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestSentinelLocalCache: - - def test_get_from_cache(self, local_cache, master): - master.set("foo", "bar") - # get key from redis and save in local cache - assert master.get("foo") == b"bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert master.get("foo") == b"barbar" - - @pytest.mark.parametrize( - "sentinel_setup", - [{"kwargs": {"decode_responses": True}}], - indirect=True, - ) - def test_cache_decode_response(self, local_cache, sentinel_setup, master): - master.set("foo", "bar") - # get key from redis and save in local cache - assert master.get("foo") == "bar" - # get key from local cache - assert local_cache.get(("GET", "foo")) == "bar" - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert local_cache.get(("GET", "foo")) is None - # get key from redis - assert master.get("foo") == "barbar" - - -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") -@pytest.mark.onlynoncluster -class TestCustomCache: - class _CustomCache(AbstractCache): - def __init__(self): - self.responses = cachetools.LRUCache(maxsize=1000) - self.keys_to_commands = defaultdict(list) - self.commands_to_keys = defaultdict(list) - - def set( - self, - command: Union[str, Sequence[str]], - response: ResponseT, - keys_in_command: List[KeyT], - ): - self.responses[command] = response - for key in keys_in_command: - self.keys_to_commands[key].append(tuple(command)) - self.commands_to_keys[command].append(tuple(keys_in_command)) - - def get(self, command: Union[str, Sequence[str]]) -> ResponseT: - return self.responses.get(command) - - def delete_command(self, command: Union[str, Sequence[str]]): - self.responses.pop(command, None) - keys = self.commands_to_keys.pop(command, []) - for key in keys: - if command in self.keys_to_commands[key]: - self.keys_to_commands[key].remove(command) - - def delete_commands(self, commands: List[Union[str, Sequence[str]]]): - for command in commands: - self.delete_command(command) - - def flush(self): - self.responses.clear() - self.commands_to_keys.clear() - self.keys_to_commands.clear() - - def invalidate_key(self, key: KeyT): - commands = self.keys_to_commands.pop(key, []) - for command in commands: - self.delete_command(command) - - @pytest.mark.parametrize("r", [{"cache": _CustomCache()}], indirect=True) - def test_get_from_cache(self, r, r2): - r, cache = r - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # send any command to redis (process invalidation in background) - r.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - # get key from redis - assert r.get("foo") == b"barbar" + time.sleep(0.5) + + after_invalidation = r.get("key") + print(f'after invalidation {after_invalidation}') + assert after_invalidation == b"foo" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 5a32bd6a7e..229e0fc6e6 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -208,7 +208,6 @@ def cmd_init_mock(self, r): def mock_node_resp(node, response): connection = Mock() connection.read_response.return_value = response - connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -216,7 +215,6 @@ def mock_node_resp(node, response): def mock_node_resp_func(node, func): connection = Mock() connection.read_response.side_effect = func - connection._get_from_local_cache.return_value = None node.redis_connection.connection = connection return node @@ -485,7 +483,6 @@ def mock_execute_command(*_args, **_kwargs): redis_mock_node.execute_command.side_effect = mock_execute_command # Mock response value for all other commands redis_mock_node.parse_response.return_value = "MOCK_OK" - redis_mock_node.connection._get_from_local_cache.return_value = None for node in r.get_nodes(): if node.port != primary.port: node.redis_connection = redis_mock_node From 48607e9d46b6b3b503f0a08cb36e4a82f9c914e8 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 6 Aug 2024 16:57:14 +0300 Subject: [PATCH 02/78] Temporary refactor --- redis/cache.py | 107 +++++++++++- redis/client.py | 14 +- redis/connection.py | 401 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 433 insertions(+), 89 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 2a50565554..c79d5af6a9 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,5 +1,6 @@ from typing import Callable, TypeVar, Any, NoReturn, List, Union from typing import Optional +from enum import Enum from cachetools import TTLCache, Cache, LRUCache from cachetools.keys import hashkey @@ -9,6 +10,110 @@ T = TypeVar('T') +class EvictionPolicy(Enum): + LRU = "lru" + LFU = "lfu" + RANDOM = "random" + + +class CacheConfiguration: + DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU + + DEFAULT_ALLOW_LIST = [ + "BITCOUNT", + "BITFIELD_RO", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUSBYMEMBER_RO", + "GEORADIUS_RO", + "GEOSEARCH", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "JSON.ARRINDEX", + "JSON.ARRLEN", + "JSON.GET", + "JSON.MGET", + "JSON.OBJKEYS", + "JSON.OBJLEN", + "JSON.RESP", + "JSON.STRLEN", + "JSON.TYPE", + "LCS", + "LINDEX", + "LLEN", + "LPOS", + "LRANGE", + "MGET", + "SCARD", + "SDIFF", + "SINTER", + "SINTERCARD", + "SISMEMBER", + "SMEMBERS", + "SMISMEMBER", + "SORT_RO", + "STRLEN", + "SUBSTR", + "SUNION", + "TS.GET", + "TS.INFO", + "TS.RANGE", + "TS.REVRANGE", + "TYPE", + "XLEN", + "XPENDING", + "XRANGE", + "XREAD", + "XREVRANGE", + "ZCARD", + "ZCOUNT", + "ZDIFF", + "ZINTER", + "ZINTERCARD", + "ZLEXCOUNT", + "ZMSCORE", + "ZRANGE", + "ZRANGEBYLEX", + "ZRANGEBYSCORE", + "ZRANK", + "ZREVRANGE", + "ZREVRANGEBYLEX", + "ZREVRANGEBYSCORE", + "ZREVRANK", + "ZSCORE", + "ZUNION", + ] + + def __init__(self, **kwargs): + self._max_size = kwargs.get("cache_size", 10000) + self._ttl = kwargs.get("cache_ttl", 0) + self._eviction_policy = kwargs.get("eviction_policy", self.DEFAULT_EVICTION_POLICY) + + def get_ttl(self) -> int: + return self._ttl + + def get_eviction_policy(self) -> EvictionPolicy: + return self._eviction_policy + + def is_exceeds_max_size(self, count: int) -> bool: + return count > self._max_size + + def is_allowed_to_cache(self, command: str) -> bool: + return command in self.DEFAULT_ALLOW_LIST + + def ensure_string(key): if isinstance(key, bytes): return key.decode('utf-8') @@ -118,7 +223,7 @@ def _on_connect(self, conn): conn._parser.set_invalidation_push_handler(self._cache_invalidation_process) def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] + self, data: List[Union[str, Optional[List[str]]]] ) -> None: """ Invalidate (delete) all redis commands associated with a specific key. diff --git a/redis/client.py b/redis/client.py index 6d3a11a2ba..adbf380b8e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -311,6 +311,15 @@ def __init__( "ssl_ciphers": ssl_ciphers, } ) + if use_cache and protocol in [3, "3"]: + kwargs.update( + { + "use_cache": use_cache, + "cache": cache, + "cache_size": cache_size, + "cache_ttl": cache_ttl, + } + ) connection_pool = ConnectionPool(**kwargs) self.auto_close_connection_pool = True else: @@ -320,7 +329,6 @@ def __init__( if use_cache and self.connection_pool.get_protocol() not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") - CacheMixin.__init__(self, use_cache, self.connection_pool, cache, cache_size, cache_ttl) self.connection = None if single_connection_client: @@ -535,7 +543,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response """ - conn.send_command(*args) + conn.send_command(*args, **options) return self.parse_response(conn, command_name, **options) def _disconnect_raise(self, conn, error): @@ -553,8 +561,6 @@ def _disconnect_raise(self, conn, error): # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): - if self.use_cache: - return self.cached_call(self._execute_command, *args, **options) return self._execute_command(*args, **options) def _execute_command(self, *args, **options): diff --git a/redis/connection.py b/redis/connection.py index cf2c7e97d5..5d3640e816 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,6 +11,9 @@ from time import time from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse +from cachetools import TTLCache, Cache, LRUCache +from cachetools.keys import hashkey +from redis.cache import CacheConfiguration from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff @@ -99,9 +102,9 @@ def pack(self, *args): # output list if we're sending large values or memoryviews arg_length = len(arg) if ( - len(buff) > buffer_cutoff - or arg_length > buffer_cutoff - or isinstance(arg, memoryview) + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) ): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) @@ -124,32 +127,96 @@ def pack(self, *args): return output -class AbstractConnection: +class ConnectionInterface: + @abstractmethod + def repr_pieces(self): + pass + + @abstractmethod + def register_connect_callback(self, callback): + pass + + @abstractmethod + def deregister_connect_callback(self, callback): + pass + + @abstractmethod + def set_parser(self, parser_class): + pass + + @abstractmethod + def connect(self): + pass + + @abstractmethod + def on_connect(self): + pass + + @abstractmethod + def disconnect(self, *args): + pass + + @abstractmethod + def check_health(self): + pass + + @abstractmethod + def send_packed_command(self, command, check_health=True): + pass + + @abstractmethod + def send_command(self, *args, **kwargs): + pass + + @abstractmethod + def can_read(self, timeout=0): + pass + + @abstractmethod + def read_response( + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, + ): + pass + + @abstractmethod + def pack_command(self, *args): + pass + + @abstractmethod + def pack_commands(self, commands): + pass + + +class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" def __init__( - self, - db: int = 0, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - retry_on_timeout: bool = False, - retry_on_error=SENTINEL, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - parser_class=DefaultParser, - socket_read_size: int = 65536, - health_check_interval: int = 0, - client_name: Optional[str] = None, - lib_name: Optional[str] = "redis-py", - lib_version: Optional[str] = get_lib_version(), - username: Optional[str] = None, - retry: Union[Any, None] = None, - redis_connect_func: Optional[Callable[[], None]] = None, - credential_provider: Optional[CredentialProvider] = None, - protocol: Optional[int] = 2, - command_packer: Optional[Callable[[], None]] = None, + self, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error=SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class=DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer: Optional[Callable[[], None]] = None, ): """ Initialize a new Connection. @@ -324,8 +391,8 @@ def on_connect(self): # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( - self.credential_provider - or UsernamePasswordCredentialProvider(self.username, self.password) + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() @@ -373,8 +440,8 @@ def on_connect(self): self.send_command("HELLO", self.protocol) response = self.read_response() if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") @@ -493,11 +560,11 @@ def can_read(self, timeout=0): raise ConnectionError(f"Error while reading from {host_error}: {e.args}") def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): """Read the response from a previously sent command""" @@ -553,9 +620,9 @@ def pack_commands(self, commands): for chunk in self._command_packer.pack(*cmd): chunklen = len(chunk) if ( - buffer_length > buffer_cutoff - or chunklen > buffer_cutoff - or isinstance(chunk, memoryview) + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) ): if pieces: output.append(SYM_EMPTY.join(pieces)) @@ -572,23 +639,21 @@ def pack_commands(self, commands): output.append(SYM_EMPTY.join(pieces)) return output - def process_invalidation_messages(self): - print(f'connection {self} {id(self)} process invalidations') - while self.can_read(): - self.read_response(push_request=True) + def get_protocol(self) -> int or str: + return self.protocol class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" def __init__( - self, - host="localhost", - port=6379, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, - **kwargs, + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, ): self.host = host self.port = int(port) @@ -610,7 +675,7 @@ def _connect(self): # socket.connect() err = None for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + self.host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -648,6 +713,151 @@ def _host_error(self): return f"{self.host}:{self.port}" +def ensure_string(key): + if isinstance(key, bytes): + return key.decode('utf-8') + elif isinstance(key, str): + return key + else: + raise TypeError("Key must be either a string or bytes") + + +class CacheProxyConnection(ConnectionInterface): + def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration): + self.pid = os.getpid() + self._conn = conn + self.retry = self._conn.retry + self._cache = cache + self._conf = conf + self._current_command_hash = None + self._current_command_keys = None + self._current_options = None + self._keys_mapping = LRUCache(maxsize=10000) + self.register_connect_callback(self._enable_tracking_callback) + + def repr_pieces(self): + return self._conn.repr_pieces() + + def register_connect_callback(self, callback): + self._conn.register_connect_callback(callback) + + def deregister_connect_callback(self, callback): + self._conn.deregister_connect_callback(callback) + + def set_parser(self, parser_class): + self._conn.set_parser(parser_class) + + def connect(self): + self._conn.connect() + + def on_connect(self): + self._conn.on_connect() + + def disconnect(self, *args): + self._conn.disconnect(*args) + + def check_health(self): + self._conn.check_health() + + def send_packed_command(self, command, check_health=True): + cache_key = hashkey(command) + + if self._cache.get(cache_key): + self._current_command_hash = cache_key + return + + self._current_command_hash = None + self._conn.send_packed_command(command) + + def send_command(self, *args, **kwargs): + if not self._conf.is_allowed_to_cache(args[0]): + self._current_command_hash = None + self._current_command_keys = None + self._conn.send_command(*args, **kwargs) + return + + self._current_command_hash = hashkey(*args) + + if kwargs.get("keys"): + self._current_command_keys = kwargs["keys"] + + if not isinstance(self._current_command_keys, list): + raise TypeError("Cache keys must be a list.") + + if self._cache.get(self._current_command_hash): + return + + self._conn.send_command(*args, **kwargs) + + def can_read(self, timeout=0): + return self._conn.can_read(timeout) + + def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False): + response = self._conn.read_response( + disable_decoding=disable_decoding, + disconnect_on_error=disconnect_on_error, + push_request=push_request + ) + + if isinstance(response, List) and len(response) > 0 and response[0] == 'invalidate': + self._on_invalidation_callback(response) + self.read_response( + disable_decoding=disable_decoding, + disconnect_on_error=disconnect_on_error, + push_request=push_request + ) + + if response is None or self._current_command_hash is None: + return response + + if self._current_command_hash in self._cache: + return self._cache[self._current_command_hash] + + for key in self._current_command_keys: + if key in self._keys_mapping: + if self._current_command_hash not in self._keys_mapping[key]: + self._keys_mapping[key].append(self._current_command_hash) + else: + self._keys_mapping[key] = [self._current_command_hash] + + self._cache[self._current_command_hash] = response + return response + + def pack_command(self, *args): + pass + + def pack_commands(self, commands): + pass + + def _connect(self): + self._conn._connect() + + def _host_error(self): + self._conn._host_error() + + def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: + conn.send_command('CLIENT', 'TRACKING', 'ON') + conn.read_response() + + def _process_pending_invalidations(self): + print(f'connection {self} {id(self)} process invalidations') + while self.can_read(): + self.read_response(push_request=True) + + def _on_invalidation_callback( + self, data: List[Union[str, Optional[List[str]]]] + ): + if data[1] is None: + self._cache.clear() + else: + for key in data[1]: + normalized_key = ensure_string(key) + if normalized_key in self._keys_mapping: + for cache_key in self._keys_mapping[normalized_key]: + self._cache.pop(cache_key) + + + class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). This class extends the Connection class, adding SSL functionality, and making @@ -655,22 +865,22 @@ class SSLConnection(Connection): """ # noqa def __init__( - self, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_ca_path=None, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - **kwargs, + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_ca_path=None, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + ssl_min_version=None, + ssl_ciphers=None, + **kwargs, ): """Constructor @@ -757,9 +967,9 @@ def _wrap_socket_with_ssl(self, sock): password=self.certificate_password, ) if ( - self.ca_certs is not None - or self.ca_path is not None - or self.ca_data is not None + self.ca_certs is not None + or self.ca_path is not None + or self.ca_data is not None ): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data @@ -875,9 +1085,9 @@ def to_bool(value): def parse_url(url): if not ( - url.startswith("redis://") - or url.startswith("rediss://") - or url.startswith("unix://") + url.startswith("redis://") + or url.startswith("rediss://") + or url.startswith("unix://") ): raise ValueError( "Redis URL must specify one of the following " @@ -994,18 +1204,37 @@ class initializer. In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, - connection_class=Connection, - max_connections: Optional[int] = None, - **connection_kwargs, + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, ): - max_connections = max_connections or 2**31 + max_connections = max_connections or 2 ** 31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.max_connections = max_connections + self._cache = None + self._cache_conf = None + + if connection_kwargs.get("use_cache"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + + self._cache_conf = CacheConfiguration(**self.connection_kwargs) + + if self.connection_kwargs.get("cache"): + self._cache = self.connection_kwargs.get("cache") + else: + self._cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) + + connection_kwargs.pop("use_cache", None) + connection_kwargs.pop("cache_size", None) + connection_kwargs.pop("cache_ttl", None) + connection_kwargs.pop("cache", None) + # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as @@ -1138,11 +1367,15 @@ def get_encoder(self) -> Encoder: decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self) -> "Connection": + def make_connection(self) -> "ConnectionInterface": "Create a new connection" if self._created_connections >= self.max_connections: raise ConnectionError("Too many connections") self._created_connections += 1 + + if self._cache is not None and self._cache_conf is not None: + return CacheProxyConnection(self.connection_class(**self.connection_kwargs), self._cache, self._cache_conf) + return self.connection_class(**self.connection_kwargs) def release(self, connection: "Connection") -> None: @@ -1236,12 +1469,12 @@ class BlockingConnectionPool(ConnectionPool): """ def __init__( - self, - max_connections=50, - timeout=20, - connection_class=Connection, - queue_class=LifoQueue, - **connection_kwargs, + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout From 2c0c8122261c92db47a9ac214d00d22ee0ca6374 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 7 Aug 2024 11:38:25 +0300 Subject: [PATCH 03/78] Finished CacheProxyConnection implementation, added comments --- redis/connection.py | 49 ++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 5d3640e816..7335c94358 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -760,59 +760,56 @@ def check_health(self): self._conn.check_health() def send_packed_command(self, command, check_health=True): - cache_key = hashkey(command) - - if self._cache.get(cache_key): - self._current_command_hash = cache_key - return - - self._current_command_hash = None + self._process_pending_invalidations() + # TODO: Investigate if it's possible to unpack command or extract keys from packed command self._conn.send_packed_command(command) def send_command(self, *args, **kwargs): + self._process_pending_invalidations() + + # If command is write command or not allowed to cache skip it. if not self._conf.is_allowed_to_cache(args[0]): self._current_command_hash = None self._current_command_keys = None self._conn.send_command(*args, **kwargs) return + # Create hash representation of current executed command. self._current_command_hash = hashkey(*args) + # Extract keys from current command. if kwargs.get("keys"): self._current_command_keys = kwargs["keys"] if not isinstance(self._current_command_keys, list): raise TypeError("Cache keys must be a list.") + # If current command reply already cached prevent sending data over socket. if self._cache.get(self._current_command_hash): return + # Send command over socket only if it's read-only command that not yet cached. self._conn.send_command(*args, **kwargs) def can_read(self, timeout=0): return self._conn.can_read(timeout) def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False): + # Check if command response exists in a cache. + if self._current_command_hash in self._cache: + return self._cache[self._current_command_hash] + response = self._conn.read_response( disable_decoding=disable_decoding, disconnect_on_error=disconnect_on_error, push_request=push_request ) - if isinstance(response, List) and len(response) > 0 and response[0] == 'invalidate': - self._on_invalidation_callback(response) - self.read_response( - disable_decoding=disable_decoding, - disconnect_on_error=disconnect_on_error, - push_request=push_request - ) - + # Check if command that was sent is write command to prevent caching of write replies. if response is None or self._current_command_hash is None: return response - if self._current_command_hash in self._cache: - return self._cache[self._current_command_hash] - + # Create separate mapping for keys or add current response to associated keys. for key in self._current_command_keys: if key in self._keys_mapping: if self._current_command_hash not in self._keys_mapping[key]: @@ -824,10 +821,10 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus return response def pack_command(self, *args): - pass + return self._conn.pack_command(*args) def pack_commands(self, commands): - pass + return self._conn.pack_commands(commands) def _connect(self): self._conn._connect() @@ -838,24 +835,27 @@ def _host_error(self): def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: conn.send_command('CLIENT', 'TRACKING', 'ON') conn.read_response() + conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) def _process_pending_invalidations(self): - print(f'connection {self} {id(self)} process invalidations') while self.can_read(): - self.read_response(push_request=True) + self._conn.read_response(push_request=True) def _on_invalidation_callback( self, data: List[Union[str, Optional[List[str]]]] ): + # Flush cache when DB flushed on server-side if data[1] is None: self._cache.clear() else: for key in data[1]: normalized_key = ensure_string(key) if normalized_key in self._keys_mapping: + # Make sure that all command responses associated with this key will be deleted for cache_key in self._keys_mapping[normalized_key]: self._cache.pop(cache_key) - + # Removes key from mapping cache + self._keys_mapping.pop(normalized_key) class SSLConnection(Connection): @@ -1235,7 +1235,6 @@ def __init__( connection_kwargs.pop("cache_ttl", None) connection_kwargs.pop("cache", None) - # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as # after a fork. during this time, multiple threads in the child @@ -1343,7 +1342,7 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if connection.can_read(): + if connection.can_read() and self._cache is None: raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() From a7343036a6876722d1b2a8332b064eceb7080544 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 7 Aug 2024 15:31:15 +0300 Subject: [PATCH 04/78] Added test cases and scheduler dependency --- redis/client.py | 3 +- redis/connection.py | 61 +++++++++++--- requirements.txt | 1 + tests/test_cache.py | 199 +++++++++++++++++++++++++++++++++++++++----- 4 files changed, 228 insertions(+), 36 deletions(-) diff --git a/redis/client.py b/redis/client.py index adbf380b8e..532c48b646 100755 --- a/redis/client.py +++ b/redis/client.py @@ -15,7 +15,6 @@ _RedisCallbacksRESP3, bool_ok, ) -from redis.cache import CacheMixin from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -86,7 +85,7 @@ class AbstractRedis: pass -class Redis(RedisModuleCommands, CoreCommands, SentinelCommands, CacheMixin): +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ Implementation of the Redis protocol. diff --git a/redis/connection.py b/redis/connection.py index 7335c94358..9493303e59 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,6 +11,8 @@ from time import time from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse + +from apscheduler.schedulers.background import BackgroundScheduler from cachetools import TTLCache, Cache, LRUCache from cachetools.keys import hashkey from redis.cache import CacheConfiguration @@ -788,15 +790,21 @@ def send_command(self, *args, **kwargs): if self._cache.get(self._current_command_hash): return - # Send command over socket only if it's read-only command that not yet cached. + # Set temporary entry as a status to prevent race condition from another connection. + self._cache[self._current_command_hash] = "caching-in-progress" + + # Send command over socket only if it's allowed read-only command that not yet cached. self._conn.send_command(*args, **kwargs) def can_read(self, timeout=0): return self._conn.can_read(timeout) def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False): - # Check if command response exists in a cache. - if self._current_command_hash in self._cache: + # Check if command response exists in a cache and it's not in progress. + if ( + self._current_command_hash in self._cache + and self._cache[self._current_command_hash] != "caching-in-progress" + ): return self._cache[self._current_command_hash] response = self._conn.read_response( @@ -805,8 +813,12 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus push_request=push_request ) - # Check if command that was sent is write command to prevent caching of write replies. - if response is None or self._current_command_hash is None: + # If response is None prevent from caching and remove temporary cache entry. + if response is None: + self._cache.pop(self._current_command_hash) + return response + # Prevent not-allowed command from caching. + elif self._current_command_hash is None: return response # Create separate mapping for keys or add current response to associated keys. @@ -817,7 +829,12 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus else: self._keys_mapping[key] = [self._current_command_hash] - self._cache[self._current_command_hash] = response + cache_entry = self._cache.get(self._current_command_hash, None) + + # Cache only responses that still valid and wasn't invalidated by another connection in meantime + if cache_entry is not None: + self._cache[self._current_command_hash] = response + return response def pack_command(self, *args): @@ -1218,6 +1235,7 @@ def __init__( self.max_connections = max_connections self._cache = None self._cache_conf = None + self._scheduler = None if connection_kwargs.get("use_cache"): if connection_kwargs.get("protocol") not in [3, "3"]: @@ -1225,15 +1243,20 @@ def __init__( self._cache_conf = CacheConfiguration(**self.connection_kwargs) - if self.connection_kwargs.get("cache"): - self._cache = self.connection_kwargs.get("cache") + cache = self.connection_kwargs.get("cache") + if cache is not None: + self._cache = cache else: self._cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) - connection_kwargs.pop("use_cache", None) - connection_kwargs.pop("cache_size", None) - connection_kwargs.pop("cache_ttl", None) - connection_kwargs.pop("cache", None) + # self.scheduler = BackgroundScheduler() + # self.scheduler.add_job(self._perform_health_check, "interval", seconds=2) + # self.scheduler.start() + + connection_kwargs.pop("use_cache", None) + connection_kwargs.pop("cache_size", None) + connection_kwargs.pop("cache_ttl", None) + connection_kwargs.pop("cache", None) # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as @@ -1246,6 +1269,10 @@ def __init__( self._fork_lock = threading.Lock() self.reset() + def __del__(self): + if self._scheduler is not None: + self.scheduler.shutdown() + def __repr__(self) -> (str, str): return ( f"<{type(self).__module__}.{type(self).__name__}" @@ -1432,6 +1459,16 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry + def _perform_health_check(self) -> None: + self._checkpid() + with self._lock: + while self._available_connections: + conn = self._available_connections.pop() + self._in_use_connections.add(conn) + conn.send_command('PING') + conn.read_response() + self.release(conn) + class BlockingConnectionPool(ConnectionPool): """ diff --git a/requirements.txt b/requirements.txt index 26aed50b9d..98c67e1a42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ async-timeout>=4.0.3 cachetools +apscheduler diff --git a/tests/test_cache.py b/tests/test_cache.py index 4eda78ebbb..158708ed05 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,35 +1,190 @@ import time +import pytest +from cachetools import TTLCache, LRUCache, LFUCache + +import redis from redis import Redis, RedisCluster +from redis.utils import HIREDIS_AVAILABLE +from tests.conftest import _get_client + + +@pytest.fixture() +def r(request): + use_cache = request.param.get("use_cache", False) + cache = request.param.get("cache") + kwargs = request.param.get("kwargs", {}) + protocol = request.param.get("protocol", 3) + single_connection_client = request.param.get("single_connection_client", False) + with _get_client( + redis.Redis, + request, + protocol=protocol, + single_connection_client=single_connection_client, + use_cache=use_cache, + cache=cache, + **kwargs, + ) as client: + yield client, cache + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +class TestCache: + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(3), "use_cache": True}], + indirect=True, + ) + def test_cache_lru_eviction(self, r): + r, cache = r + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) == b"bar2" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) + def test_cache_ttl(self, r, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + @pytest.mark.parametrize( + "r", + [{"cache": LFUCache(3), "use_cache": True}], + indirect=True, + ) + def test_cache_lfu_eviction(self, r, cache): + r, cache = r + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # test the eviction policy + assert cache.currsize == 3 + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) is None -def test_standalone_cached_get_and_set(): - r = Redis(use_cache=True, protocol=3) - assert r.set("key", 5) - assert r.get("key") == b"5" + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(maxsize=128), "use_cache": True}], + indirect=True, + ) + def test_cache_ignore_not_allowed_command(self, r): + r, cache = r + # add fields to hash + assert r.hset("foo", "bar", "baz") + # get random field + assert r.hrandfield("foo") == b"bar" + assert cache.get(("HRANDFIELD", "foo")) is None - r2 = Redis(protocol=3) - r2.set("key", "foo") + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(maxsize=128), "use_cache": True}], + indirect=True, + ) + def test_cache_invalidate_all_related_responses(self, r, cache): + r, cache = r + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") - time.sleep(0.5) + # Make sure that replies was cached + assert r.mget("foo", "bar") == [b"bar", b"foo"] + assert cache.get(("MGET", "foo", "bar")) == [b"bar", b"foo"] - after_invalidation = r.get("key") - print(f'after invalidation {after_invalidation}') - assert after_invalidation == b"foo" + # Invalidate one of the keys and make sure that all associated cached entries was removed + assert r.set("foo", "baz") + assert r.get("foo") == b"baz" + assert cache.get(("MGET", "foo", "bar")) is None + assert cache.get(("GET", "foo")) == b"baz" + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(maxsize=128), "use_cache": True}], + indirect=True, + ) + def test_cache_flushed_on_server_flush(self, r, cache): + r, cache = r + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") + assert r.set("baz", "bar") -def test_cluster_cached_get_and_set(): - cluster_url = "redis://localhost:16379/0" + # Make sure that replies was cached + assert r.get("foo") == b"bar" + assert r.get("bar") == b"foo" + assert r.get("baz") == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + assert cache.get(("GET", "baz")) == b"bar" - r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) - assert r.set("key", 5) - assert r.get("key") == b"5" + # Flush server and trying to access cached entry + assert r.flushall() + assert r.get("foo") is None + assert cache.currsize == 0 - r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) - r2.set("key", "foo") - time.sleep(0.5) - - after_invalidation = r.get("key") - print(f'after invalidation {after_invalidation}') - assert after_invalidation == b"foo" +# def test_cluster_cached_get_and_set(): +# cluster_url = "redis://localhost:16379/0" +# +# r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) +# assert r.set("key", 5) +# assert r.get("key") == b"5" +# +# r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) +# r2.set("key", "foo") +# +# time.sleep(0.5) +# +# after_invalidation = r.get("key") +# print(f'after invalidation {after_invalidation}') +# assert after_invalidation == b"foo" From 339735cd544a540c426a2a8597d48a37038c9b37 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 8 Aug 2024 16:27:46 +0300 Subject: [PATCH 05/78] Added support for RedisCluster and multi-threaded test cases --- redis/cluster.py | 50 ++++--- redis/connection.py | 17 +-- tests/conftest.py | 2 +- tests/test_cache.py | 334 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 358 insertions(+), 45 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 39e8c4b9ea..9d135a3c0c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -11,7 +11,6 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.cache import CacheMixin from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -446,7 +445,7 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None: self.nodes_manager.default_node = random.choice(replicas) -class RedisCluster(AbstractRedisCluster, RedisClusterCommands, CacheMixin): +class RedisCluster(AbstractRedisCluster, RedisClusterCommands): @classmethod def from_url(cls, url, **kwargs): """ @@ -504,7 +503,10 @@ def __init__( dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - use_cache: Optional[bool] = False, + use_cache: bool = False, + cache: Optional[Cache] = None, + cache_size: int = 128, + cache_ttl: int = 300, **kwargs, ): """ @@ -628,6 +630,10 @@ def __init__( kwargs.get("encoding_errors", "strict"), kwargs.get("decode_responses", False), ) + protocol = kwargs.get("protocol", None) + if use_cache and protocol not in [3, "3"]: + raise RedisError("Client caching is only supported with RESP version 3") + self.cluster_error_retry_attempts = cluster_error_retry_attempts self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() @@ -641,6 +647,9 @@ def __init__( dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, use_cache=use_cache, + cache=cache, + cache_size=cache_size, + cache_ttl=cache_ttl, **kwargs, ) @@ -649,11 +658,6 @@ def __init__( ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) - protocol = kwargs.get("protocol", None) - if use_cache and protocol not in [3, "3"]: - raise RedisError("Client caching is only supported with RESP version 3") - CacheMixin.__init__(self, use_cache, None) - self.commands_parser = CommandsParser(self) self._lock = threading.Lock() @@ -1057,8 +1061,6 @@ def _parse_target_nodes(self, target_nodes): return nodes def execute_command(self, *args, **options): - if self.use_cache: - return self.cached_call(self._execute_command, *args, **options) return self._internal_execute_command(*args, **options) def _internal_execute_command(self, *args, **kwargs): @@ -1163,7 +1165,7 @@ def _execute_command(self, target_node, *args, **kwargs): connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - connection.send_command(*args) + connection.send_command(*args, **kwargs) response = redis_node.parse_response(connection, command, **kwargs) if command in self.cluster_response_callbacks: response = self.cluster_response_callbacks[command]( @@ -1317,7 +1319,7 @@ def reset(self) -> None: self.primary_to_idx.clear() -class NodesManager(CacheMixin): +class NodesManager(): def __init__( self, startup_nodes, @@ -1327,8 +1329,10 @@ def __init__( dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - use_cache: Optional[bool] = False, + use_cache: bool = False, cache: Optional[Cache] = None, + cache_size: int = 128, + cache_ttl: int = 300, **kwargs, ): self.nodes_cache = {} @@ -1342,13 +1346,15 @@ def __init__( self.connection_pool_class = connection_pool_class self.address_remap = address_remap self.use_cache = use_cache + self.cache = cache + self.cache_size = cache_size + self.cache_ttl = cache_ttl self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() if lock is None: lock = threading.Lock() self._lock = lock - CacheMixin.__init__(self, use_cache, None, cache) self.initialize() def get_node(self, host=None, port=None, node_name=None): @@ -1486,9 +1492,21 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - r = Redis(connection_pool=self.connection_pool_class(**kwargs), use_cache=self.use_cache, cache=self.cache) + kwargs.update({"use_cache": self.use_cache}) + kwargs.update({"cache": self.cache}) + kwargs.update({"cache_size": self.cache_size}) + kwargs.update({"cache_ttl": self.cache_ttl}) + r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: - r = Redis(host=host, port=port, use_cache=self.use_cache, cache=self.cache, **kwargs) + r = Redis( + host=host, + port=port, + use_cache=self.use_cache, + cache=self.cache, + cache_size=self.cache_size, + cache_ttl=self.cache_ttl, + **kwargs, + ) return r def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): diff --git a/redis/connection.py b/redis/connection.py index 9493303e59..011885d16d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,7 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from time import time +from time import time, sleep from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse @@ -756,6 +756,7 @@ def on_connect(self): self._conn.on_connect() def disconnect(self, *args): + self._cache.clear() self._conn.disconnect(*args) def check_health(self): @@ -1235,7 +1236,7 @@ def __init__( self.max_connections = max_connections self._cache = None self._cache_conf = None - self._scheduler = None + self.scheduler = None if connection_kwargs.get("use_cache"): if connection_kwargs.get("protocol") not in [3, "3"]: @@ -1249,9 +1250,9 @@ def __init__( else: self._cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) - # self.scheduler = BackgroundScheduler() - # self.scheduler.add_job(self._perform_health_check, "interval", seconds=2) - # self.scheduler.start() + self.scheduler = BackgroundScheduler() + self.scheduler.add_job(self._perform_health_check, "interval", seconds=2, id="cache_health_check") + self.scheduler.start() connection_kwargs.pop("use_cache", None) connection_kwargs.pop("cache_size", None) @@ -1269,10 +1270,6 @@ def __init__( self._fork_lock = threading.Lock() self.reset() - def __del__(self): - if self._scheduler is not None: - self.scheduler.shutdown() - def __repr__(self) -> (str, str): return ( f"<{type(self).__module__}.{type(self).__name__}" @@ -1464,10 +1461,8 @@ def _perform_health_check(self) -> None: with self._lock: while self._available_connections: conn = self._available_connections.pop() - self._in_use_connections.add(conn) conn.send_command('PING') conn.read_response() - self.release(conn) class BlockingConnectionPool(ConnectionPool): diff --git a/tests/conftest.py b/tests/conftest.py index 97d73773ba..0222164332 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ from redis.retry import Retry REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/0" +default_redis_url = "redis://localhost:6372/0" default_protocol = "2" default_redismod_url = "redis://localhost:6479" diff --git a/tests/test_cache.py b/tests/test_cache.py index 158708ed05..96ca68a819 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,3 +1,4 @@ +import threading import time import pytest @@ -28,11 +29,16 @@ def r(request): yield client, cache +def set_get(client, key, value): + client.set(key, value) + return client.get(key) + + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") class TestCache: @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlynoncluster - def test_get_from_cache(self, r, r2): + def test_get_from_cache(self, r, r2, cache): r, cache = r # add key to redis r.set("foo", "bar") @@ -47,12 +53,94 @@ def test_get_from_cache(self, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_get_from_cache_multithreaded(self, r, cache): + r, cache = r + # Running commands over two threads + threading.Thread(target=set_get, args=(r, "foo", "bar")).start() + threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that both values was cached. + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + + # Running commands over two threads + threading.Thread(target=set_get, args=(r, "foo", "baz")).start() + threading.Thread(target=set_get, args=(r, "bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that new values was cached. + assert cache.get(("GET", "foo")) == b"baz" + assert cache.get(("GET", "bar")) == b"bar" + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache(self, r, r2, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Wait for health check + time.sleep(2) + # Make sure that value was invalidated + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): + r, cache = r + # Running commands over two threads + threading.Thread(target=set_get, args=(r, "foo", "bar")).start() + threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + # Wait for command execution to be finished + time.sleep(0.1) + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + # change key in redis (cause invalidation) + threading.Thread(target=r2.set, args=("foo", "baz")).start() + threading.Thread(target=r2.set, args=("bar", "bar")).start() + # Wait for health check + time.sleep(2) + # Trigger object destructor to shutdown health check thread + del r + # Make sure that value was invalidated + assert cache.get(("GET", "foo")) is None + assert cache.get(("GET", "bar")) is None + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_cache_clears_on_disconnect(self, r, r2, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # Force disconnection + r.connection_pool.get_connection('_').disconnect() + # Make sure cache is empty + assert cache.currsize == 0 + @pytest.mark.parametrize( "r", [{"cache": LRUCache(3), "use_cache": True}], indirect=True, ) - def test_cache_lru_eviction(self, r): + @pytest.mark.onlynoncluster + def test_cache_lru_eviction(self, r, cache): r, cache = r # add 3 keys to redis r.set("foo", "bar") @@ -73,6 +161,7 @@ def test_cache_lru_eviction(self, r): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster def test_cache_ttl(self, r, cache): r, cache = r # add key to redis @@ -91,6 +180,7 @@ def test_cache_ttl(self, r, cache): [{"cache": LFUCache(3), "use_cache": True}], indirect=True, ) + @pytest.mark.onlynoncluster def test_cache_lfu_eviction(self, r, cache): r, cache = r # add 3 keys to redis @@ -118,6 +208,7 @@ def test_cache_lfu_eviction(self, r, cache): [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True, ) + @pytest.mark.onlynoncluster def test_cache_ignore_not_allowed_command(self, r): r, cache = r # add fields to hash @@ -131,6 +222,7 @@ def test_cache_ignore_not_allowed_command(self, r): [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True, ) + @pytest.mark.onlynoncluster def test_cache_invalidate_all_related_responses(self, r, cache): r, cache = r # Add keys @@ -152,6 +244,7 @@ def test_cache_invalidate_all_related_responses(self, r, cache): [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True, ) + @pytest.mark.onlynoncluster def test_cache_flushed_on_server_flush(self, r, cache): r, cache = r # Add keys @@ -173,18 +266,225 @@ def test_cache_flushed_on_server_flush(self, r, cache): assert cache.currsize == 0 -# def test_cluster_cached_get_and_set(): -# cluster_url = "redis://localhost:16379/0" -# -# r = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) -# assert r.set("key", 5) -# assert r.get("key") == b"5" -# -# r2 = RedisCluster.from_url(cluster_url, use_cache=True, protocol=3) -# r2.set("key", "foo") -# -# time.sleep(0.5) -# -# after_invalidation = r.get("key") -# print(f'after invalidation {after_invalidation}') -# assert after_invalidation == b"foo" +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlycluster +class TestClusterCache: + @pytest.mark.parametrize("r", [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True) + def test_get_from_cache(self, r, r2): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_get_from_cache_multithreaded(self, r, r2, cache): + r, cache = r + # Running commands over two threads + threading.Thread(target=set_get, args=(r, "foo", "bar")).start() + threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that both values was cached. + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + + # Running commands over two threads + threading.Thread(target=set_get, args=(r, "foo", "baz")).start() + threading.Thread(target=set_get, args=(r, "bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that new values was cached. + assert cache.get(("GET", "foo")) == b"baz" + assert cache.get(("GET", "bar")) == b"bar" + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache(self, r, r2, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Wait for health check + time.sleep(2) + # Make sure that value was invalidated + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): + r, cache = r + # Running commands over two threads + threading.Thread(target=set_get, args=(r, "foo", "bar")).start() + threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + # Wait for command execution to be finished + time.sleep(0.1) + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + # change key in redis (cause invalidation) + threading.Thread(target=r2.set, args=("foo", "baz")).start() + threading.Thread(target=r2.set, args=("bar", "bar")).start() + # Wait for health check + time.sleep(2) + # Make sure that value was invalidated + assert cache.get(("GET", "foo")) is None + assert cache.get(("GET", "bar")) is None + + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.onlynoncluster + def test_cache_clears_on_disconnect(self, r, r2, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # Force disconnection + r.nodes_manager.get_node_from_slot(10).redis_connection.connection_pool.get_connection("_").disconnect() + # Make sure cache is empty + assert cache.currsize == 0 + + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(3), "use_cache": True}], + indirect=True, + ) + def test_cache_lru_eviction(self, r, cache): + r, cache = r + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # get the 3 keys from local cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) == b"bar2" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # the first key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) + def test_cache_ttl(self, r, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # wait for the key to expire + time.sleep(1) + # the key is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize( + "r", + [{"cache": LFUCache(3), "use_cache": True}], + indirect=True, + ) + def test_cache_lfu_eviction(self, r, cache): + r, cache = r + # add 3 keys to redis + r.set("foo", "bar") + r.set("foo2", "bar2") + r.set("foo3", "bar3") + # get 3 keys from redis and save in local cache + assert r.get("foo") == b"bar" + assert r.get("foo2") == b"bar2" + assert r.get("foo3") == b"bar3" + # change the order of the keys in the cache + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo3")) == b"bar3" + # add 1 more key to redis (exceed the max size) + r.set("foo4", "bar4") + assert r.get("foo4") == b"bar4" + # test the eviction policy + assert cache.currsize == 3 + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "foo2")) is None + + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(maxsize=128), "use_cache": True}], + indirect=True, + ) + def test_cache_ignore_not_allowed_command(self, r): + r, cache = r + # add fields to hash + assert r.hset("foo", "bar", "baz") + # get random field + assert r.hrandfield("foo") == b"bar" + assert cache.get(("HRANDFIELD", "foo")) is None + + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(maxsize=128), "use_cache": True}], + indirect=True, + ) + def test_cache_invalidate_all_related_responses(self, r, cache): + r, cache = r + # Add keys + assert r.set("foo{slot}", "bar") + assert r.set("bar{slot}", "foo") + + # Make sure that replies was cached + assert r.mget("foo{slot}", "bar{slot}") == [b"bar", b"foo"] + assert cache.get(("MGET", "foo{slot}", "bar{slot}")) == [b"bar", b"foo"] + + # Invalidate one of the keys and make sure that all associated cached entries was removed + assert r.set("foo{slot}", "baz") + assert r.get("foo{slot}") == b"baz" + assert cache.get(("MGET", "foo{slot}", "bar{slot}")) is None + assert cache.get(("GET", "foo{slot}")) == b"baz" + + @pytest.mark.parametrize( + "r", + [{"cache": LRUCache(maxsize=128), "use_cache": True}], + indirect=True, + ) + def test_cache_flushed_on_server_flush(self, r, cache): + r, cache = r + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") + assert r.set("baz", "bar") + + # Make sure that replies was cached + assert r.get("foo") == b"bar" + assert r.get("bar") == b"foo" + assert r.get("baz") == b"bar" + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + assert cache.get(("GET", "baz")) == b"bar" + + # Flush server and trying to access cached entry + assert r.flushall() + assert r.get("foo") is None + assert cache.currsize == 0 From 88f7e549df69ac3f5ed8ed0a6e80cfd339c22adb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 8 Aug 2024 16:37:58 +0300 Subject: [PATCH 06/78] Added support for BlockingConnectionPool --- redis/connection.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index 011885d16d..82aa4ab40e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1541,7 +1541,14 @@ def reset(self): def make_connection(self): "Make a fresh connection." - connection = self.connection_class(**self.connection_kwargs) + if self._cache is not None and self._cache_conf is not None: + connection = CacheProxyConnection( + self.connection_class(**self.connection_kwargs), + self._cache, + self._cache_conf + ) + else: + connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection From 936be8489d619287bdd0b727a705c7c65ce751b6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 8 Aug 2024 16:44:01 +0300 Subject: [PATCH 07/78] Fixed docker-compose command --- tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks.py b/tasks.py index 0fef093f3c..76737b8eff 100644 --- a/tasks.py +++ b/tasks.py @@ -13,7 +13,7 @@ def devenv(c): """Brings up the test environment, by wrapping docker compose.""" clean(c) - cmd = "docker-compose --profile all up -d --build" + cmd = "docker compose --profile all up -d --build" run(cmd) @@ -85,7 +85,7 @@ def clean(c): shutil.rmtree("build") if os.path.isdir("dist"): shutil.rmtree("dist") - run("docker-compose --profile all rm -s -f") + run("docker compose --profile all rm -s -f") @task From eb95bd30cf9aad4dd186a84982b20970755a6232 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 8 Aug 2024 16:57:09 +0300 Subject: [PATCH 08/78] Revert port changes --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0222164332..97d73773ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ from redis.retry import Retry REDIS_INFO = {} -default_redis_url = "redis://localhost:6372/0" +default_redis_url = "redis://localhost:6379/0" default_protocol = "2" default_redismod_url = "redis://localhost:6479" From 2c50adcefda8142afe461610d123f1f74f825122 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 9 Aug 2024 14:56:05 +0300 Subject: [PATCH 09/78] Initial take on Sentinel support --- tests/conftest.py | 5 +++-- tests/test_cache.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 97d73773ba..4efcfa4d3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -410,7 +410,7 @@ def sslclient(request): @pytest.fixture() -def sentinel_setup(local_cache, request): +def sentinel_setup(cache, request): sentinel_ips = request.config.getoption("--sentinels") sentinel_endpoints = [ (ip.strip(), int(port.strip())) @@ -420,7 +420,8 @@ def sentinel_setup(local_cache, request): sentinel = Sentinel( sentinel_endpoints, socket_timeout=0.1, - client_cache=local_cache, + use_cache=cache, + cache=cache, protocol=3, **kwargs, ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 96ca68a819..f44ff4e955 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -488,3 +488,21 @@ def test_cache_flushed_on_server_flush(self, r, cache): assert r.flushall() assert r.get("foo") is None assert cache.currsize == 0 + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestSentinelCache: + def test_get_from_cache(self, cache, master): + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # send any command to redis (process invalidation in background) + master.ping() + # the command is not in the local cache anymore + assert cache.get(("GET", "foo")) is None + # get key from redis + assert master.get("foo") == b"barbar" From 6f582a07d1170904e5c9286033c2e77ce6de52eb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 9 Aug 2024 17:08:38 +0300 Subject: [PATCH 10/78] Remove keys option after usage --- redis/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/redis/client.py b/redis/client.py index 532c48b646..b3c42b3c20 100755 --- a/redis/client.py +++ b/redis/client.py @@ -594,6 +594,9 @@ def parse_response(self, connection, command_name, **options): if EMPTY_RESPONSE in options: options.pop(EMPTY_RESPONSE) + # Remove keys entry, it needs only for cache. + options.pop("keys", None) + if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response From 59fe379bbca107b159143299f1c15e121a640589 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 9 Aug 2024 17:15:50 +0300 Subject: [PATCH 11/78] Added condition to remove keys entry on async --- redis/asyncio/client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 5d93c83b12..696431d4c8 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -641,6 +641,9 @@ async def parse_response( if EMPTY_RESPONSE in options: options.pop(EMPTY_RESPONSE) + # Remove keys entry, it needs only for cache. + options.pop("keys", None) + if command_name in self.response_callbacks: # Mypy bug: https://github.com/python/mypy/issues/10977 command_name = cast(str, command_name) From eaeef1293974ffb62b1f064a28a99a31cf09fe17 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 9 Aug 2024 17:23:49 +0300 Subject: [PATCH 12/78] Added same keys entry removal in pipeline --- redis/client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/redis/client.py b/redis/client.py index b3c42b3c20..f5e45e581c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1421,6 +1421,8 @@ def _execute_transaction(self, connection, commands, raise_on_error) -> List: for r, cmd in zip(response, commands): if not isinstance(r, Exception): args, options = cmd + # Remove keys entry, it needs only for cache. + options.pop("keys", None) command_name = args[0] if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) From 33f656e76dc9c5dc3d85e36ec7bee02e188c93e3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 12 Aug 2024 13:15:23 +0300 Subject: [PATCH 13/78] Added caching support for Sentinel --- redis/connection.py | 4 ++- redis/sentinel.py | 6 +++- tests/conftest.py | 19 +++++++--- tests/test_cache.py | 88 ++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 102 insertions(+), 15 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 82aa4ab40e..f571fafc0b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -729,6 +729,8 @@ def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfigura self.pid = os.getpid() self._conn = conn self.retry = self._conn.retry + self.host = self._conn.host + self.port = self._conn.port self._cache = cache self._conf = conf self._current_command_hash = None @@ -770,7 +772,7 @@ def send_packed_command(self, command, check_health=True): def send_command(self, *args, **kwargs): self._process_pending_invalidations() - # If command is write command or not allowed to cache skip it. + # If command is write command or not allowed to cache, transfer control to the actual connection. if not self._conf.is_allowed_to_cache(args[0]): self._current_command_hash = None self._current_command_keys = None diff --git a/redis/sentinel.py b/redis/sentinel.py index e0437c81cd..857e831527 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -229,6 +229,7 @@ def __init__( sentinels, min_other_sentinels=0, sentinel_kwargs=None, + force_master_ip=None, **connection_kwargs, ): # if sentinel_kwargs isn't defined, use the socket_* options from @@ -245,6 +246,7 @@ def __init__( ] self.min_other_sentinels = min_other_sentinels self.connection_kwargs = connection_kwargs + self._force_master_ip = force_master_ip def execute_command(self, *args, **kwargs): """ @@ -304,7 +306,9 @@ def discover_master(self, service_name): sentinel, self.sentinels[0], ) - return state["ip"], state["port"] + + ip = self._force_master_ip if self._force_master_ip is not None else state["ip"] + return ip, state["port"] error_info = "" if len(collected_errors) > 0: diff --git a/tests/conftest.py b/tests/conftest.py index 4efcfa4d3b..750ab7213e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -410,22 +410,30 @@ def sslclient(request): @pytest.fixture() -def sentinel_setup(cache, request): +def sentinel_setup(request): sentinel_ips = request.config.getoption("--sentinels") sentinel_endpoints = [ (ip.strip(), int(port.strip())) for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) ] kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} + use_cache = request.param.get("use_cache", False) + cache = request.param.get("cache", None) + cache_size = request.param.get("cache_size", 128) + cache_ttl = request.param.get("cache_ttl", 300) + force_master_ip = request.param.get("force_master_ip", None) sentinel = Sentinel( sentinel_endpoints, + force_master_ip=force_master_ip, socket_timeout=0.1, - use_cache=cache, + use_cache=use_cache, cache=cache, + cache_ttl=cache_ttl, + cache_size=cache_size, protocol=3, **kwargs, ) - yield sentinel + yield sentinel, cache for s in sentinel.sentinels: s.close() @@ -433,8 +441,9 @@ def sentinel_setup(cache, request): @pytest.fixture() def master(request, sentinel_setup): master_service = request.config.getoption("--master-service") - master = sentinel_setup.master_for(master_service) - yield master + sentinel, cache = sentinel_setup + master = sentinel.master_for(master_service) + yield master, cache master.close() diff --git a/tests/test_cache.py b/tests/test_cache.py index f44ff4e955..c12b2ea744 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -5,7 +5,6 @@ from cachetools import TTLCache, LRUCache, LFUCache import redis -from redis import Redis, RedisCluster from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client @@ -113,8 +112,6 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): threading.Thread(target=r2.set, args=("bar", "bar")).start() # Wait for health check time.sleep(2) - # Trigger object destructor to shutdown health check thread - del r # Make sure that value was invalidated assert cache.get(("GET", "foo")) is None assert cache.get(("GET", "bar")) is None @@ -492,7 +489,14 @@ def test_cache_flushed_on_server_flush(self, r, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster class TestSentinelCache: - def test_get_from_cache(self, cache, master): + @pytest.mark.parametrize( + "sentinel_setup", + [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_get_from_cache(self, master, cache): + master, cache = master master.set("foo", "bar") # get key from redis and save in local cache assert master.get("foo") == b"bar" @@ -500,9 +504,77 @@ def test_get_from_cache(self, cache, master): assert cache.get(("GET", "foo")) == b"bar" # change key in redis (cause invalidation) master.set("foo", "barbar") - # send any command to redis (process invalidation in background) - master.ping() - # the command is not in the local cache anymore - assert cache.get(("GET", "foo")) is None # get key from redis assert master.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + + @pytest.mark.parametrize( + "sentinel_setup", + [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_get_from_cache_multithreaded(self, master, cache): + master, cache = master + # Running commands over two threads + threading.Thread(target=set_get, args=(master, "foo", "bar")).start() + threading.Thread(target=set_get, args=(master, "bar", "foo")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that both values was cached. + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + + # Running commands over two threads + threading.Thread(target=set_get, args=(master, "foo", "baz")).start() + threading.Thread(target=set_get, args=(master, "bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that new values was cached. + assert cache.get(("GET", "foo")) == b"baz" + assert cache.get(("GET", "bar")) == b"bar" + + @pytest.mark.parametrize( + "sentinel_setup", + [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache(self, master, cache): + master, cache = master + # add key to redis + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + master.set("foo", "barbar") + # Wait for health check + time.sleep(2) + # Make sure that value was invalidated + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize( + "sentinel_setup", + [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_cache_clears_on_disconnect(self, master, cache): + master, cache = master + # add key to redis + master.set("foo", "bar") + # get key from redis and save in local cache + assert master.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # Force disconnection + master.connection_pool.get_connection('_').disconnect() + # Make sure cache is empty + assert cache.currsize == 0 From e77cb60dcd12f3afa8526aff883adc83c0db892a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 13 Aug 2024 11:49:52 +0300 Subject: [PATCH 14/78] Added locking when accesing cache object --- redis/connection.py | 106 +++++++++++++++++++++++++------------------- tests/test_cache.py | 106 +++++++++++++++++++++++++++++++++----------- 2 files changed, 139 insertions(+), 73 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index f571fafc0b..9e9953c2cb 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -725,13 +725,14 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): - def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration): + def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration, cache_lock: threading.Lock): self.pid = os.getpid() self._conn = conn self.retry = self._conn.retry self.host = self._conn.host self.port = self._conn.port self._cache = cache + self._cache_lock = cache_lock self._conf = conf self._current_command_hash = None self._current_command_keys = None @@ -758,7 +759,8 @@ def on_connect(self): self._conn.on_connect() def disconnect(self, *args): - self._cache.clear() + with self._cache_lock: + self._cache.clear() self._conn.disconnect(*args) def check_health(self): @@ -789,12 +791,13 @@ def send_command(self, *args, **kwargs): if not isinstance(self._current_command_keys, list): raise TypeError("Cache keys must be a list.") - # If current command reply already cached prevent sending data over socket. - if self._cache.get(self._current_command_hash): - return + with self._cache_lock: + # If current command reply already cached prevent sending data over socket. + if self._cache.get(self._current_command_hash): + return - # Set temporary entry as a status to prevent race condition from another connection. - self._cache[self._current_command_hash] = "caching-in-progress" + # Set temporary entry as a status to prevent race condition from another connection. + self._cache[self._current_command_hash] = "caching-in-progress" # Send command over socket only if it's allowed read-only command that not yet cached. self._conn.send_command(*args, **kwargs) @@ -803,12 +806,13 @@ def can_read(self, timeout=0): return self._conn.can_read(timeout) def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False): - # Check if command response exists in a cache and it's not in progress. - if ( - self._current_command_hash in self._cache - and self._cache[self._current_command_hash] != "caching-in-progress" - ): - return self._cache[self._current_command_hash] + with self._cache_lock: + # Check if command response exists in a cache and it's not in progress. + if ( + self._current_command_hash in self._cache + and self._cache[self._current_command_hash] != "caching-in-progress" + ): + return self._cache[self._current_command_hash] response = self._conn.read_response( disable_decoding=disable_decoding, @@ -816,27 +820,28 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus push_request=push_request ) - # If response is None prevent from caching and remove temporary cache entry. - if response is None: - self._cache.pop(self._current_command_hash) - return response - # Prevent not-allowed command from caching. - elif self._current_command_hash is None: - return response - - # Create separate mapping for keys or add current response to associated keys. - for key in self._current_command_keys: - if key in self._keys_mapping: - if self._current_command_hash not in self._keys_mapping[key]: - self._keys_mapping[key].append(self._current_command_hash) - else: - self._keys_mapping[key] = [self._current_command_hash] + with self._cache_lock: + # If response is None prevent from caching and remove temporary cache entry. + if response is None: + self._cache.pop(self._current_command_hash) + return response + # Prevent not-allowed command from caching. + elif self._current_command_hash is None: + return response + + # Create separate mapping for keys or add current response to associated keys. + for key in self._current_command_keys: + if key in self._keys_mapping: + if self._current_command_hash not in self._keys_mapping[key]: + self._keys_mapping[key].append(self._current_command_hash) + else: + self._keys_mapping[key] = [self._current_command_hash] - cache_entry = self._cache.get(self._current_command_hash, None) + cache_entry = self._cache.get(self._current_command_hash, None) - # Cache only responses that still valid and wasn't invalidated by another connection in meantime - if cache_entry is not None: - self._cache[self._current_command_hash] = response + # Cache only responses that still valid and wasn't invalidated by another connection in meantime + if cache_entry is not None: + self._cache[self._current_command_hash] = response return response @@ -864,18 +869,19 @@ def _process_pending_invalidations(self): def _on_invalidation_callback( self, data: List[Union[str, Optional[List[str]]]] ): - # Flush cache when DB flushed on server-side - if data[1] is None: - self._cache.clear() - else: - for key in data[1]: - normalized_key = ensure_string(key) - if normalized_key in self._keys_mapping: - # Make sure that all command responses associated with this key will be deleted - for cache_key in self._keys_mapping[normalized_key]: - self._cache.pop(cache_key) - # Removes key from mapping cache - self._keys_mapping.pop(normalized_key) + with self._cache_lock: + # Flush cache when DB flushed on server-side + if data[1] is None: + self._cache.clear() + else: + for key in data[1]: + normalized_key = ensure_string(key) + if normalized_key in self._keys_mapping: + # Make sure that all command responses associated with this key will be deleted + for cache_key in self._keys_mapping[normalized_key]: + self._cache.pop(cache_key) + # Removes key from mapping cache + self._keys_mapping.pop(normalized_key) class SSLConnection(Connection): @@ -1238,6 +1244,7 @@ def __init__( self.max_connections = max_connections self._cache = None self._cache_conf = None + self.cache_lock = None self.scheduler = None if connection_kwargs.get("use_cache"): @@ -1245,6 +1252,7 @@ def __init__( raise RedisError("Client caching is only supported with RESP version 3") self._cache_conf = CacheConfiguration(**self.connection_kwargs) + self._cache_lock = threading.Lock() cache = self.connection_kwargs.get("cache") if cache is not None: @@ -1399,7 +1407,12 @@ def make_connection(self) -> "ConnectionInterface": self._created_connections += 1 if self._cache is not None and self._cache_conf is not None: - return CacheProxyConnection(self.connection_class(**self.connection_kwargs), self._cache, self._cache_conf) + return CacheProxyConnection( + self.connection_class(**self.connection_kwargs), + self._cache, + self._cache_conf, + self._cache_lock + ) return self.connection_class(**self.connection_kwargs) @@ -1547,7 +1560,8 @@ def make_connection(self): connection = CacheProxyConnection( self.connection_class(**self.connection_kwargs), self._cache, - self._cache_conf + self._cache_conf, + self._cache_lock ) else: connection = self.connection_class(**self.connection_kwargs) diff --git a/tests/test_cache.py b/tests/test_cache.py index c12b2ea744..56c1e74563 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -57,19 +57,30 @@ def test_get_from_cache(self, r, r2, cache): def test_get_from_cache_multithreaded(self, r, cache): r, cache = r # Running commands over two threads - threading.Thread(target=set_get, args=(r, "foo", "bar")).start() - threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + threading.Thread(target=r.set("foo", "bar")).start() + threading.Thread(target=r.set("bar", "foo")).start() # Wait for command execution to be finished time.sleep(0.1) - # Make sure that both values was cached. + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that responses was cached. assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "bar")) == b"foo" - # Running commands over two threads - threading.Thread(target=set_get, args=(r, "foo", "baz")).start() - threading.Thread(target=set_get, args=(r, "bar", "bar")).start() + threading.Thread(target=r.set("foo", "baz")).start() + threading.Thread(target=r.set("bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() # Wait for command execution to be finished time.sleep(0.1) @@ -100,16 +111,21 @@ def test_health_check_invalidate_cache(self, r, r2, cache): def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): r, cache = r # Running commands over two threads - threading.Thread(target=set_get, args=(r, "foo", "bar")).start() - threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + threading.Thread(target=r.set("foo", "bar")).start() + threading.Thread(target=r.set("bar", "foo")).start() + # Wait for command execution to be finished + time.sleep(0.1) + # get keys from server + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() # Wait for command execution to be finished time.sleep(0.1) # get key from local cache assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "bar")) == b"foo" # change key in redis (cause invalidation) - threading.Thread(target=r2.set, args=("foo", "baz")).start() - threading.Thread(target=r2.set, args=("bar", "bar")).start() + threading.Thread(target=r2.set("foo", "baz")).start() + threading.Thread(target=r2.set("bar", "bar")).start() # Wait for health check time.sleep(2) # Make sure that value was invalidated @@ -283,12 +299,18 @@ def test_get_from_cache(self, r, r2): assert cache.get(("GET", "foo")) == b"barbar" @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) - @pytest.mark.onlynoncluster - def test_get_from_cache_multithreaded(self, r, r2, cache): + @pytest.mark.onlycluster + def test_get_from_cache_multithreaded(self, r, cache): r, cache = r # Running commands over two threads - threading.Thread(target=set_get, args=(r, "foo", "bar")).start() - threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + threading.Thread(target=r.set("foo", "bar")).start() + threading.Thread(target=r.set("bar", "foo")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() # Wait for command execution to be finished time.sleep(0.1) @@ -298,8 +320,14 @@ def test_get_from_cache_multithreaded(self, r, r2, cache): assert cache.get(("GET", "bar")) == b"foo" # Running commands over two threads - threading.Thread(target=set_get, args=(r, "foo", "baz")).start() - threading.Thread(target=set_get, args=(r, "bar", "bar")).start() + threading.Thread(target=r.set("foo", "baz")).start() + threading.Thread(target=r.set("bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() # Wait for command execution to be finished time.sleep(0.1) @@ -309,7 +337,7 @@ def test_get_from_cache_multithreaded(self, r, r2, cache): assert cache.get(("GET", "bar")) == b"bar" @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) - @pytest.mark.onlynoncluster + @pytest.mark.onlycluster def test_health_check_invalidate_cache(self, r, r2, cache): r, cache = r # add key to redis @@ -326,20 +354,23 @@ def test_health_check_invalidate_cache(self, r, r2, cache): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) - @pytest.mark.onlynoncluster + @pytest.mark.onlycluster def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): r, cache = r # Running commands over two threads - threading.Thread(target=set_get, args=(r, "foo", "bar")).start() - threading.Thread(target=set_get, args=(r, "bar", "foo")).start() + threading.Thread(target=r.set("foo", "bar")).start() + threading.Thread(target=r.set("bar", "foo")).start() # Wait for command execution to be finished time.sleep(0.1) + # get keys from server + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() # get key from local cache assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "bar")) == b"foo" # change key in redis (cause invalidation) - threading.Thread(target=r2.set, args=("foo", "baz")).start() - threading.Thread(target=r2.set, args=("bar", "bar")).start() + threading.Thread(target=r.set("foo", "baz")).start() + threading.Thread(target=r.set("bar", "bar")).start() # Wait for health check time.sleep(2) # Make sure that value was invalidated @@ -347,7 +378,7 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): assert cache.get(("GET", "bar")) is None @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) - @pytest.mark.onlynoncluster + @pytest.mark.onlycluster def test_cache_clears_on_disconnect(self, r, r2, cache): r, cache = r # add key to redis @@ -366,6 +397,7 @@ def test_cache_clears_on_disconnect(self, r, r2, cache): [{"cache": LRUCache(3), "use_cache": True}], indirect=True, ) + @pytest.mark.onlycluster def test_cache_lru_eviction(self, r, cache): r, cache = r # add 3 keys to redis @@ -387,6 +419,7 @@ def test_cache_lru_eviction(self, r, cache): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) + @pytest.mark.onlycluster def test_cache_ttl(self, r, cache): r, cache = r # add key to redis @@ -405,6 +438,7 @@ def test_cache_ttl(self, r, cache): [{"cache": LFUCache(3), "use_cache": True}], indirect=True, ) + @pytest.mark.onlycluster def test_cache_lfu_eviction(self, r, cache): r, cache = r # add 3 keys to redis @@ -432,6 +466,7 @@ def test_cache_lfu_eviction(self, r, cache): [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True, ) + @pytest.mark.onlycluster def test_cache_ignore_not_allowed_command(self, r): r, cache = r # add fields to hash @@ -445,6 +480,7 @@ def test_cache_ignore_not_allowed_command(self, r): [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True, ) + @pytest.mark.onlycluster def test_cache_invalidate_all_related_responses(self, r, cache): r, cache = r # Add keys @@ -466,6 +502,7 @@ def test_cache_invalidate_all_related_responses(self, r, cache): [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True, ) + @pytest.mark.onlycluster def test_cache_flushed_on_server_flush(self, r, cache): r, cache = r # Add keys @@ -517,9 +554,17 @@ def test_get_from_cache(self, master, cache): @pytest.mark.onlynoncluster def test_get_from_cache_multithreaded(self, master, cache): master, cache = master + + # Running commands over two threads + threading.Thread(target=master.set("foo", "bar")).start() + threading.Thread(target=master.set("bar", "foo")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + # Running commands over two threads - threading.Thread(target=set_get, args=(master, "foo", "bar")).start() - threading.Thread(target=set_get, args=(master, "bar", "foo")).start() + threading.Thread(target=master.get("foo")).start() + threading.Thread(target=master.get("bar")).start() # Wait for command execution to be finished time.sleep(0.1) @@ -529,8 +574,15 @@ def test_get_from_cache_multithreaded(self, master, cache): assert cache.get(("GET", "bar")) == b"foo" # Running commands over two threads - threading.Thread(target=set_get, args=(master, "foo", "baz")).start() - threading.Thread(target=set_get, args=(master, "bar", "bar")).start() + threading.Thread(target=master.set("foo", "baz")).start() + threading.Thread(target=master.set("bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Running commands over two threads + threading.Thread(target=master.get("foo")).start() + threading.Thread(target=master.get("bar")).start() # Wait for command execution to be finished time.sleep(0.1) From 2a14e13bd577dbae5e1fe456487dd535fd6db5cd Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 13 Aug 2024 12:52:35 +0300 Subject: [PATCH 15/78] Rmoved keys option from options --- redis/asyncio/cluster.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index cbceccf401..4e82e5448f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -1034,6 +1034,9 @@ async def parse_response( if EMPTY_RESPONSE in kwargs: kwargs.pop(EMPTY_RESPONSE) + # Remove keys entry, it needs only for cache. + kwargs.pop("keys", None) + # Return response if command in self.response_callbacks: return self.response_callbacks[command](response, **kwargs) From 64fb176aa88c51531d36fa99d9dcb9061dc66b3f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 15 Aug 2024 16:47:15 +0300 Subject: [PATCH 16/78] Removed redundant entities --- redis/cache.py | 136 +------------------------------------------------ 1 file changed, 1 insertion(+), 135 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index c79d5af6a9..9608afbce6 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -111,138 +111,4 @@ def is_exceeds_max_size(self, count: int) -> bool: return count > self._max_size def is_allowed_to_cache(self, command: str) -> bool: - return command in self.DEFAULT_ALLOW_LIST - - -def ensure_string(key): - if isinstance(key, bytes): - return key.decode('utf-8') - elif isinstance(key, str): - return key - else: - raise TypeError("Key must be either a string or bytes") - - -class CacheMixin: - def __init__(self, - use_cache: bool, - connection_pool: "ConnectionPool", - cache: Optional[Cache] = None, - cache_size: int = 128, - cache_ttl: int = 300, - ) -> None: - self.use_cache = use_cache - if not use_cache: - return - if cache is not None: - self.cache = cache - else: - self.cache = TTLCache(maxsize=cache_size, ttl=cache_ttl) - self.keys_mapping = LRUCache(maxsize=10000) - self.wrap_connection_pool(connection_pool) - self.connections = [] - - def cached_call(self, - func: Callable[..., ResponseT], - *args, - **options) -> ResponseT: - if not self.use_cache: - return func(*args, **options) - - print(f'Cached call with args {args} and options {options}') - - keys = None - if 'keys' in options: - keys = options['keys'] - if not isinstance(keys, list): - raise TypeError("Cache keys must be a list.") - if not keys: - return func(*args, **options) - print(f'keys {keys}') - - cache_key = hashkey(*args) - - for conn in self.connections: - conn.process_invalidation_messages() - - for key in keys: - if key in self.keys_mapping: - if cache_key not in self.keys_mapping[key]: - self.keys_mapping[key].append(cache_key) - else: - self.keys_mapping[key] = [cache_key] - - if cache_key in self.cache: - result = self.cache[cache_key] - print(f'Cached call for {args} yields cached result {result}') - return result - - result = func(*args, **options) - self.cache[cache_key] = result - print(f'Cached call for {args} yields computed result {result}') - return result - - def get_cache_entry(self, *args: Any) -> Any: - cache_key = hashkey(*args) - return self.cache.get(cache_key, None) - - def invalidate_cache_entry(self, *args: Any) -> None: - cache_key = hashkey(*args) - if cache_key in self.cache: - self.cache.pop(cache_key) - - def wrap_connection_pool(self, connection_pool: "ConnectionPool"): - if not self.use_cache: - return - if connection_pool is None: - return - original_maker = connection_pool.make_connection - connection_pool.make_connection = lambda: self._make_connection(original_maker) - - def _make_connection(self, original_maker: Callable[[], "Connection"]): - conn = original_maker() - original_disconnect = conn.disconnect - conn.disconnect = lambda: self._wrapped_disconnect(conn, original_disconnect) - self.add_connection(conn) - return conn - - def _wrapped_disconnect(self, connection: "Connection", - original_disconnect: Callable[[], NoReturn]): - original_disconnect() - self.remove_connection(connection) - - def add_connection(self, conn): - print(f'Tracking connection {conn} {id(conn)}') - conn.register_connect_callback(self._on_connect) - self.connections.append(conn) - - def _on_connect(self, conn): - conn.send_command("CLIENT", "TRACKING", "ON") - response = conn.read_response() - print(f"Client tracking response {response}") - conn._parser.set_invalidation_push_handler(self._cache_invalidation_process) - - def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] - ) -> None: - """ - Invalidate (delete) all redis commands associated with a specific key. - `data` is a list of strings, where the first string is the invalidation message - and the second string is the list of keys to invalidate. - (if the list of keys is None, then all keys are invalidated) - """ - print(f'Invalidation {data}') - if data[1] is None: - self.cache.clear() - else: - for key in data[1]: - normalized_key = ensure_string(key) - print(f'Invalidating normalized key {normalized_key}') - if normalized_key in self.keys_mapping: - for cache_key in self.keys_mapping[normalized_key]: - print(f'Invalidating cache key {cache_key}') - self.cache.pop(cache_key) - - def remove_connection(self, conn): - print(f'Untracking connection {conn} {id(conn)}') - self.connections.remove(conn) + return command in self.DEFAULT_ALLOW_LIST \ No newline at end of file From 47e6c7abfe8fde67067f1ce7778c2102a9f65a7b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 16 Aug 2024 14:51:10 +0300 Subject: [PATCH 17/78] Added cache support for SSLConnection --- redis/connection.py | 4 +++- redis/retry.py | 42 +++++++++++++++++++++++++++++++++++++++++- tests/conftest.py | 15 ++++++++++++++- tests/test_cache.py | 28 ++++++++++++++++++++++++++++ tests/test_cluster.py | 6 +++--- 5 files changed, 89 insertions(+), 6 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 9e9953c2cb..ebb82f7410 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -863,7 +863,9 @@ def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) def _process_pending_invalidations(self): - while self.can_read(): + while self.retry.call_with_retry_on_false( + lambda: self.can_read() + ): self._conn.read_response(push_request=True) def _on_invalidation_callback( diff --git a/redis/retry.py b/redis/retry.py index 03fd973c4c..0f563e344a 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,6 +1,7 @@ import socket +import time from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Optional from redis.exceptions import ConnectionError, TimeoutError @@ -68,3 +69,42 @@ def call_with_retry( backoff = self._backoff.compute(failures) if backoff > 0: sleep(backoff) + + def call_with_retry_on_false( + self, + do: Callable[[], T], + on_false: Optional[Callable[[], T]] = None, + max_retries: Optional[int] = 3, + timeout: Optional[float] = 0, + exponent: Optional[int] = 2, + ) -> bool: + """ + Execute an operation that returns boolean value with retry + logic in case if false value been returned. + `do`: the operation to call. Expects no argument. + `on_false`: Callback to be executed on retry fail. + """ + res = do() + + if res: + return res + + if on_false is not None: + on_false() + + if max_retries > 0: + if timeout > 0: + time.sleep(timeout) + + return self.call_with_retry_on_false( + do, + on_false, + max_retries - 1, + timeout * exponent, + exponent + ) + + return False + + + diff --git a/tests/conftest.py b/tests/conftest.py index 750ab7213e..427c05673c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,10 @@ from packaging.version import Version from redis import Sentinel from redis.backoff import NoBackoff -from redis.connection import Connection, parse_url +from redis.connection import Connection, parse_url, SSLConnection from redis.exceptions import RedisClusterException from redis.retry import Retry +from tests.ssl_utils import get_ssl_filename REDIS_INFO = {} default_redis_url = "redis://localhost:6379/0" @@ -323,6 +324,18 @@ def _get_client( cluster_mode = REDIS_INFO["cluster_enabled"] if not cluster_mode: url_options = parse_url(redis_url) + connection_class = Connection + ssl = kwargs.pop("ssl", False) + if ssl: + connection_class = SSLConnection + kwargs["ssl_certfile"] = get_ssl_filename("client-cert.pem") + kwargs["ssl_keyfile"] = get_ssl_filename("client-key.pem") + # When you try to assign "required" as single string, it assigns tuple instead of string. + # Probably some reserved keyword, I can't explain how does it work -_- + kwargs["ssl_cert_reqs"] = "require"+"d" + kwargs["ssl_ca_certs"] = get_ssl_filename("ca-cert.pem") + kwargs["port"] = 6666 + kwargs["connection_class"] = connection_class url_options.update(kwargs) pool = redis.ConnectionPool(**url_options) client = cls(connection_pool=pool) diff --git a/tests/test_cache.py b/tests/test_cache.py index 56c1e74563..1528504a01 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -15,11 +15,13 @@ def r(request): cache = request.param.get("cache") kwargs = request.param.get("kwargs", {}) protocol = request.param.get("protocol", 3) + ssl = request.param.get("ssl", False) single_connection_client = request.param.get("single_connection_client", False) with _get_client( redis.Redis, request, protocol=protocol, + ssl=ssl, single_connection_client=single_connection_client, use_cache=use_cache, cache=cache, @@ -630,3 +632,29 @@ def test_cache_clears_on_disconnect(self, master, cache): master.connection_pool.get_connection('_').disconnect() # Make sure cache is empty assert cache.currsize == 0 + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster +class TestSSLCache: + @pytest.mark.parametrize("r", [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], indirect=True) + @pytest.mark.onlynoncluster + def test_get_from_cache(self, r, r2, cache): + r, cache = r + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + assert r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 229e0fc6e6..4e984ffd44 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -643,10 +643,10 @@ def parse_response_mock_third(connection, *args, **options): mocks["send_command"].assert_has_calls( [ call("READONLY"), - call("GET", "foo"), + call("GET", "foo", keys=['foo']), call("READONLY"), - call("GET", "foo"), - call("GET", "foo"), + call("GET", "foo", keys=['foo']), + call("GET", "foo", keys=['foo']), ] ) From eec44bd935b03ab304f15ed6193b5b6ec57b4d47 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 16 Aug 2024 15:05:58 +0300 Subject: [PATCH 18/78] Moved ssl argument handling to cover cluster case --- tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 427c05673c..2615a3c3fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from tests.ssl_utils import get_ssl_filename REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/0" +default_redis_url = "redis://localhost:16379/0" default_protocol = "2" default_redismod_url = "redis://localhost:6479" @@ -322,10 +322,10 @@ def _get_client( kwargs["protocol"] = request.config.getoption("--protocol") cluster_mode = REDIS_INFO["cluster_enabled"] + ssl = kwargs.pop("ssl", False) if not cluster_mode: url_options = parse_url(redis_url) connection_class = Connection - ssl = kwargs.pop("ssl", False) if ssl: connection_class = SSLConnection kwargs["ssl_certfile"] = get_ssl_filename("client-cert.pem") From 5422955426b2c2ad6f2e0419fc9df72f0172a000 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 16 Aug 2024 15:12:24 +0300 Subject: [PATCH 19/78] Revert local test changes --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2615a3c3fb..92b8a7caa5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from tests.ssl_utils import get_ssl_filename REDIS_INFO = {} -default_redis_url = "redis://localhost:16379/0" +default_redis_url = "redis://localhost:6379/0" default_protocol = "2" default_redismod_url = "redis://localhost:6479" From 41190d7208ecc4389acdb2c2ee0176d59d58dcd7 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 16 Aug 2024 15:31:33 +0300 Subject: [PATCH 20/78] Fixed bug with missing async operator --- redis/_parsers/resp3.py | 2 +- tests/test_asyncio/test_pubsub.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 0e0a6655d2..462a3da77d 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -17,7 +17,7 @@ def __init__(self, socket_read_size): self.pubsub_push_handler_func = self.handle_pubsub_push_response self.invalidation_push_handler_func = None - def handle_pubsub_push_response(self, response): + async def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 19d4b1c650..13a6158b40 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -461,7 +461,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): @pytest.mark.onlynoncluster class TestPubSubRESP3Handler: - def my_handler(self, message): + async def my_handler(self, message): self.message = ["my handler", message] @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") From 6146d131206fe1c2ebfca665eefd19aac3e7f92e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 16 Aug 2024 15:41:51 +0300 Subject: [PATCH 21/78] Revert accidental changes --- redis/_parsers/resp3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 462a3da77d..0e0a6655d2 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -17,7 +17,7 @@ def __init__(self, socket_read_size): self.pubsub_push_handler_func = self.handle_pubsub_push_response self.invalidation_push_handler_func = None - async def handle_pubsub_push_response(self, response): + def handle_pubsub_push_response(self, response): logger = getLogger("push_response") logger.info("Push response: " + str(response)) return response From d184d6b66a613b854bd1482a3292228adf696e5f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 19 Aug 2024 11:51:48 +0300 Subject: [PATCH 22/78] Added API to return cache object --- redis/asyncio/client.py | 5 + redis/client.py | 3 + redis/cluster.py | 3 + redis/connection.py | 22 ++-- tests/conftest.py | 7 +- tests/test_cache.py | 281 +++++++++++++++++++++++++++++----------- 6 files changed, 227 insertions(+), 94 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 696431d4c8..b2ad2e2db8 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -26,6 +26,8 @@ cast, ) +from cachetools import Cache + from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -651,6 +653,9 @@ async def parse_response( return await retval if inspect.isawaitable(retval) else retval return response + def get_cache(self) -> Optional[Cache]: + return self.connection_pool.cache + StrictRedis = Redis diff --git a/redis/client.py b/redis/client.py index f5e45e581c..e431c0e887 100755 --- a/redis/client.py +++ b/redis/client.py @@ -601,6 +601,9 @@ def parse_response(self, connection, command_name, **options): return self.response_callbacks[command_name](response, **options) return response + def get_cache(self) -> Optional[Cache]: + return self.connection_pool.cache + StrictRedis = Redis diff --git a/redis/cluster.py b/redis/cluster.py index 9d135a3c0c..d74019c644 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1682,6 +1682,9 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port + def get_cache(self) -> Optional[Cache]: + return self.connection_pool.cache + class ClusterPubSub(PubSub): """ diff --git a/redis/connection.py b/redis/connection.py index ebb82f7410..d9f61b1da9 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -812,7 +812,7 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus self._current_command_hash in self._cache and self._cache[self._current_command_hash] != "caching-in-progress" ): - return self._cache[self._current_command_hash] + return copy.deepcopy(self._cache[self._current_command_hash]) response = self._conn.read_response( disable_decoding=disable_decoding, @@ -821,7 +821,7 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus ) with self._cache_lock: - # If response is None prevent from caching and remove temporary cache entry. + # If response is None prevent from caching. if response is None: self._cache.pop(self._current_command_hash) return response @@ -839,7 +839,7 @@ def read_response(self, disable_decoding=False, *, disconnect_on_error=True, pus cache_entry = self._cache.get(self._current_command_hash, None) - # Cache only responses that still valid and wasn't invalidated by another connection in meantime + # Cache only responses that still valid and wasn't invalidated by another connection in meantime. if cache_entry is not None: self._cache[self._current_command_hash] = response @@ -1244,7 +1244,7 @@ def __init__( self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.max_connections = max_connections - self._cache = None + self.cache = None self._cache_conf = None self.cache_lock = None self.scheduler = None @@ -1258,9 +1258,9 @@ def __init__( cache = self.connection_kwargs.get("cache") if cache is not None: - self._cache = cache + self.cache = cache else: - self._cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) + self.cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) self.scheduler = BackgroundScheduler() self.scheduler.add_job(self._perform_health_check, "interval", seconds=2, id="cache_health_check") @@ -1378,7 +1378,7 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if connection.can_read() and self._cache is None: + if connection.can_read() and self.cache is None: raise ConnectionError("Connection has data") except (ConnectionError, OSError): connection.disconnect() @@ -1408,10 +1408,10 @@ def make_connection(self) -> "ConnectionInterface": raise ConnectionError("Too many connections") self._created_connections += 1 - if self._cache is not None and self._cache_conf is not None: + if self.cache is not None and self._cache_conf is not None: return CacheProxyConnection( self.connection_class(**self.connection_kwargs), - self._cache, + self.cache, self._cache_conf, self._cache_lock ) @@ -1558,10 +1558,10 @@ def reset(self): def make_connection(self): "Make a fresh connection." - if self._cache is not None and self._cache_conf is not None: + if self.cache is not None and self._cache_conf is not None: connection = CacheProxyConnection( self.connection_class(**self.connection_kwargs), - self._cache, + self.cache, self._cache_conf, self._cache_lock ) diff --git a/tests/conftest.py b/tests/conftest.py index 92b8a7caa5..aacb2da5e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -446,7 +446,7 @@ def sentinel_setup(request): protocol=3, **kwargs, ) - yield sentinel, cache + yield sentinel for s in sentinel.sentinels: s.close() @@ -454,9 +454,8 @@ def sentinel_setup(request): @pytest.fixture() def master(request, sentinel_setup): master_service = request.config.getoption("--master-service") - sentinel, cache = sentinel_setup - master = sentinel.master_for(master_service) - yield master, cache + master = sentinel_setup.master_for(master_service) + yield master master.close() diff --git a/tests/test_cache.py b/tests/test_cache.py index 1528504a01..843d5ae138 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -27,20 +27,22 @@ def r(request): cache=cache, **kwargs, ) as client: - yield client, cache + yield client def set_get(client, key, value): client.set(key, value) return client.get(key) - @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") class TestCache: - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_get_from_cache(self, r, r2, cache): - r, cache = r + def test_get_from_cache(self, r, r2): + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -54,10 +56,13 @@ def test_get_from_cache(self, r, r2, cache): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_get_from_cache_multithreaded(self, r, cache): - r, cache = r + def test_get_from_cache_multithreaded(self, r): + cache = r.get_cache() # Running commands over two threads threading.Thread(target=r.set("foo", "bar")).start() threading.Thread(target=r.set("bar", "foo")).start() @@ -91,10 +96,27 @@ def test_get_from_cache_multithreaded(self, r, cache): assert cache.get(("GET", "foo")) == b"baz" assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], indirect=True) @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache(self, r, r2, cache): - r, cache = r + def test_prevent_race_condition_from_multiple_threads(self, r, cache): + cache = r.get_cache() + + # Set initial key. + assert r.set("foo", "bar") + + # Running concurrent commands over two threads to override same key. + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=set_get, args=(r, "foo", "baz")).start() + assert cache.get(("GET", "foo")) == b"bar" + + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], indirect=True) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache(self, r, r2): + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -110,8 +132,8 @@ def test_health_check_invalidate_cache(self, r, r2, cache): @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): - r, cache = r + def test_health_check_invalidate_cache_multithreaded(self, r, r2): + cache = r.get_cache() # Running commands over two threads threading.Thread(target=r.set("foo", "bar")).start() threading.Thread(target=r.set("bar", "foo")).start() @@ -134,10 +156,13 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): assert cache.get(("GET", "foo")) is None assert cache.get(("GET", "bar")) is None - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_cache_clears_on_disconnect(self, r, r2, cache): - r, cache = r + def test_cache_clears_on_disconnect(self, r, cache): + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -149,14 +174,13 @@ def test_cache_clears_on_disconnect(self, r, r2, cache): # Make sure cache is empty assert cache.currsize == 0 - @pytest.mark.parametrize( - "r", - [{"cache": LRUCache(3), "use_cache": True}], - indirect=True, - ) + @pytest.mark.parametrize("r", [ + {"cache": LRUCache(3), "use_cache": True, "single_connection_client": True}, + {"cache": LRUCache(3), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster def test_cache_lru_eviction(self, r, cache): - r, cache = r + cache = r.get_cache() # add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -175,10 +199,13 @@ def test_cache_lru_eviction(self, r, cache): # the first key is not in the local cache anymore assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 1), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 1), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_cache_ttl(self, r, cache): - r, cache = r + def test_cache_ttl(self, r): + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -190,14 +217,13 @@ def test_cache_ttl(self, r, cache): # the key is not in the local cache anymore assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize( - "r", - [{"cache": LFUCache(3), "use_cache": True}], - indirect=True, - ) + @pytest.mark.parametrize("r", [ + {"cache": LFUCache(3), "use_cache": True, "single_connection_client": True}, + {"cache": LFUCache(3), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_cache_lfu_eviction(self, r, cache): - r, cache = r + def test_cache_lfu_eviction(self, r): + cache = r.get_cache() # add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -218,35 +244,39 @@ def test_cache_lfu_eviction(self, r, cache): assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "foo2")) is None - @pytest.mark.parametrize( - "r", - [{"cache": LRUCache(maxsize=128), "use_cache": True}], - indirect=True, - ) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster def test_cache_ignore_not_allowed_command(self, r): - r, cache = r + cache = r.get_cache() # add fields to hash assert r.hset("foo", "bar", "baz") # get random field assert r.hrandfield("foo") == b"bar" assert cache.get(("HRANDFIELD", "foo")) is None - @pytest.mark.parametrize( - "r", - [{"cache": LRUCache(maxsize=128), "use_cache": True}], - indirect=True, - ) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_cache_invalidate_all_related_responses(self, r, cache): - r, cache = r + def test_cache_invalidate_all_related_responses(self, r): + cache = r.get_cache() # Add keys assert r.set("foo", "bar") assert r.set("bar", "foo") + res = r.mget("foo", "bar") # Make sure that replies was cached - assert r.mget("foo", "bar") == [b"bar", b"foo"] - assert cache.get(("MGET", "foo", "bar")) == [b"bar", b"foo"] + assert res == [b"bar", b"foo"] + assert cache.get(("MGET", "foo", "bar")) == res + + # Make sure that objects are immutable. + another_res = r.mget("foo", "bar") + res.append(b"baz") + assert another_res != res # Invalidate one of the keys and make sure that all associated cached entries was removed assert r.set("foo", "baz") @@ -254,14 +284,13 @@ def test_cache_invalidate_all_related_responses(self, r, cache): assert cache.get(("MGET", "foo", "bar")) is None assert cache.get(("GET", "foo")) == b"baz" - @pytest.mark.parametrize( - "r", - [{"cache": LRUCache(maxsize=128), "use_cache": True}], - indirect=True, - ) + @pytest.mark.parametrize("r", [ + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, + {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, + ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_cache_flushed_on_server_flush(self, r, cache): - r, cache = r + def test_cache_flushed_on_server_flush(self, r): + cache = r.get_cache() # Add keys assert r.set("foo", "bar") assert r.set("bar", "foo") @@ -286,7 +315,7 @@ def test_cache_flushed_on_server_flush(self, r, cache): class TestClusterCache: @pytest.mark.parametrize("r", [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True) def test_get_from_cache(self, r, r2): - r, cache = r + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -302,8 +331,8 @@ def test_get_from_cache(self, r, r2): @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlycluster - def test_get_from_cache_multithreaded(self, r, cache): - r, cache = r + def test_get_from_cache_multithreaded(self, r): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # Running commands over two threads threading.Thread(target=r.set("foo", "bar")).start() threading.Thread(target=r.set("bar", "foo")).start() @@ -340,8 +369,8 @@ def test_get_from_cache_multithreaded(self, r, cache): @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlycluster - def test_health_check_invalidate_cache(self, r, r2, cache): - r, cache = r + def test_health_check_invalidate_cache(self, r, r2): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -357,8 +386,8 @@ def test_health_check_invalidate_cache(self, r, r2, cache): @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlycluster - def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): - r, cache = r + def test_health_check_invalidate_cache_multithreaded(self, r, r2): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # Running commands over two threads threading.Thread(target=r.set("foo", "bar")).start() threading.Thread(target=r.set("bar", "foo")).start() @@ -381,8 +410,8 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2, cache): @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlycluster - def test_cache_clears_on_disconnect(self, r, r2, cache): - r, cache = r + def test_cache_clears_on_disconnect(self, r, r2): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -400,8 +429,8 @@ def test_cache_clears_on_disconnect(self, r, r2, cache): indirect=True, ) @pytest.mark.onlycluster - def test_cache_lru_eviction(self, r, cache): - r, cache = r + def test_cache_lru_eviction(self, r): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -422,8 +451,8 @@ def test_cache_lru_eviction(self, r, cache): @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) @pytest.mark.onlycluster - def test_cache_ttl(self, r, cache): - r, cache = r + def test_cache_ttl(self, r): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -441,8 +470,8 @@ def test_cache_ttl(self, r, cache): indirect=True, ) @pytest.mark.onlycluster - def test_cache_lfu_eviction(self, r, cache): - r, cache = r + def test_cache_lfu_eviction(self, r): + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add 3 keys to redis r.set("foo", "bar") r.set("foo2", "bar2") @@ -470,7 +499,7 @@ def test_cache_lfu_eviction(self, r, cache): ) @pytest.mark.onlycluster def test_cache_ignore_not_allowed_command(self, r): - r, cache = r + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add fields to hash assert r.hset("foo", "bar", "baz") # get random field @@ -484,7 +513,7 @@ def test_cache_ignore_not_allowed_command(self, r): ) @pytest.mark.onlycluster def test_cache_invalidate_all_related_responses(self, r, cache): - r, cache = r + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # Add keys assert r.set("foo{slot}", "bar") assert r.set("bar{slot}", "foo") @@ -506,7 +535,7 @@ def test_cache_invalidate_all_related_responses(self, r, cache): ) @pytest.mark.onlycluster def test_cache_flushed_on_server_flush(self, r, cache): - r, cache = r + cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # Add keys assert r.set("foo", "bar") assert r.set("bar", "foo") @@ -534,8 +563,8 @@ class TestSentinelCache: indirect=True, ) @pytest.mark.onlynoncluster - def test_get_from_cache(self, master, cache): - master, cache = master + def test_get_from_cache(self, master): + cache = master.get_cache() master.set("foo", "bar") # get key from redis and save in local cache assert master.get("foo") == b"bar" @@ -554,8 +583,8 @@ def test_get_from_cache(self, master, cache): indirect=True, ) @pytest.mark.onlynoncluster - def test_get_from_cache_multithreaded(self, master, cache): - master, cache = master + def test_get_from_cache_multithreaded(self, master): + cache = master.get_cache() # Running commands over two threads threading.Thread(target=master.set("foo", "bar")).start() @@ -600,7 +629,7 @@ def test_get_from_cache_multithreaded(self, master, cache): ) @pytest.mark.onlynoncluster def test_health_check_invalidate_cache(self, master, cache): - master, cache = master + cache = master.get_cache() # add key to redis master.set("foo", "bar") # get key from redis and save in local cache @@ -621,7 +650,7 @@ def test_health_check_invalidate_cache(self, master, cache): ) @pytest.mark.onlynoncluster def test_cache_clears_on_disconnect(self, master, cache): - master, cache = master + cache = master.get_cache() # add key to redis master.set("foo", "bar") # get key from redis and save in local cache @@ -645,7 +674,7 @@ class TestSSLCache: ], indirect=True) @pytest.mark.onlynoncluster def test_get_from_cache(self, r, r2, cache): - r, cache = r + cache = r.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache @@ -658,3 +687,97 @@ def test_get_from_cache(self, r, r2, cache): assert r.get("foo") == b"barbar" # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" + + @pytest.mark.parametrize("r", [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], indirect=True) + @pytest.mark.onlynoncluster + def test_get_from_cache_multithreaded(self, r): + cache = r.get_cache() + # Running commands over two threads + threading.Thread(target=r.set("foo", "bar")).start() + threading.Thread(target=r.set("bar", "foo")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that responses was cached. + assert cache.get(("GET", "foo")) == b"bar" + assert cache.get(("GET", "bar")) == b"foo" + + threading.Thread(target=r.set("foo", "baz")).start() + threading.Thread(target=r.set("bar", "bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + threading.Thread(target=r.get("foo")).start() + threading.Thread(target=r.get("bar")).start() + + # Wait for command execution to be finished + time.sleep(0.1) + + # Make sure that new values was cached. + assert cache.get(("GET", "foo")) == b"baz" + assert cache.get(("GET", "bar")) == b"bar" + + @pytest.mark.parametrize("r", [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], indirect=True) + @pytest.mark.onlynoncluster + def test_health_check_invalidate_cache(self, r, r2): + cache = r.get_cache() + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Wait for health check + time.sleep(2) + # Make sure that value was invalidated + assert cache.get(("GET", "foo")) is None + + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], + indirect=True, + ) + @pytest.mark.onlynoncluster + def test_cache_invalidate_all_related_responses(self, r): + cache = r.get_cache() + # Add keys + assert r.set("foo", "bar") + assert r.set("bar", "foo") + + # Make sure that replies was cached + assert r.mget("foo", "bar") == [b"bar", b"foo"] + assert cache.get(("MGET", "foo", "bar")) == [b"bar", b"foo"] + + # Invalidate one of the keys and make sure that all associated cached entries was removed + assert r.set("foo", "baz") + assert r.get("foo") == b"baz" + assert cache.get(("MGET", "foo", "bar")) is None + assert cache.get(("GET", "foo")) == b"baz" From fe124e12dbec01fa8ac630027a111ec3ce070095 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 19 Aug 2024 14:11:27 +0300 Subject: [PATCH 23/78] Added eviction policy configuration --- redis/cache.py | 67 ++++++++++++++++++++----- redis/client.py | 3 ++ redis/cluster.py | 10 ++-- redis/connection.py | 7 ++- tests/test_cache.py | 119 +++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 181 insertions(+), 25 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 9608afbce6..13b7d867b4 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,19 +1,17 @@ -from typing import Callable, TypeVar, Any, NoReturn, List, Union -from typing import Optional +from abc import ABC, abstractmethod +from typing import TypeVar from enum import Enum -from cachetools import TTLCache, Cache, LRUCache -from cachetools.keys import hashkey - -from redis.typing import ResponseT +from cachetools import LRUCache, LFUCache, RRCache, Cache, TTLCache T = TypeVar('T') class EvictionPolicy(Enum): - LRU = "lru" - LFU = "lfu" - RANDOM = "random" + LRU = "LRU" + LFU = "LFU" + RANDOM = "RANDOM" + TTL = "TTL" class CacheConfiguration: @@ -97,13 +95,25 @@ class CacheConfiguration: ] def __init__(self, **kwargs): - self._max_size = kwargs.get("cache_size", 10000) - self._ttl = kwargs.get("cache_ttl", 0) - self._eviction_policy = kwargs.get("eviction_policy", self.DEFAULT_EVICTION_POLICY) + self._max_size = kwargs.get("cache_size", None) + self._ttl = kwargs.get("cache_ttl", None) + self._eviction_policy = kwargs.get("cache_eviction", None) + if self._max_size is None: + self._max_size = 10000 + if self._ttl is None: + self._ttl = 0 + if self._eviction_policy is None: + self._eviction_policy = EvictionPolicy.LRU + + if self._eviction_policy not in EvictionPolicy: + raise ValueError(f"Invalid eviction_policy {self._eviction_policy}") def get_ttl(self) -> int: return self._ttl + def get_max_size(self) -> int: + return self._max_size + def get_eviction_policy(self) -> EvictionPolicy: return self._eviction_policy @@ -111,4 +121,35 @@ def is_exceeds_max_size(self, count: int) -> bool: return count > self._max_size def is_allowed_to_cache(self, command: str) -> bool: - return command in self.DEFAULT_ALLOW_LIST \ No newline at end of file + return command in self.DEFAULT_ALLOW_LIST + + +class CacheClass(Enum): + LRU = LRUCache + LFU = LFUCache + RANDOM = RRCache + TTL = TTLCache + + +class CacheFactoryInterface(ABC): + @abstractmethod + def get_cache(self) -> Cache: + pass + + +class CacheFactory(CacheFactoryInterface): + def __init__(self, conf: CacheConfiguration): + self._conf = conf + + def get_cache(self) -> Cache: + eviction_policy = self._conf.get_eviction_policy() + cache_class = self._get_cache_class(eviction_policy).value + + if eviction_policy == EvictionPolicy.TTL: + return cache_class(self._conf.get_max_size(), self._conf.get_ttl()) + + return cache_class(self._conf.get_max_size()) + + def _get_cache_class(self, eviction_policy: EvictionPolicy) -> CacheClass: + return CacheClass[eviction_policy.value] + diff --git a/redis/client.py b/redis/client.py index e431c0e887..805fc3d9cd 100755 --- a/redis/client.py +++ b/redis/client.py @@ -15,6 +15,7 @@ _RedisCallbacksRESP3, bool_ok, ) +from redis.cache import EvictionPolicy from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -216,6 +217,7 @@ def __init__( protocol: Optional[int] = 2, use_cache: bool = False, cache: Optional[Cache] = None, + cache_eviction: Optional[EvictionPolicy] = None, cache_size: int = 128, cache_ttl: int = 300, ) -> None: @@ -315,6 +317,7 @@ def __init__( { "use_cache": use_cache, "cache": cache, + "cache_eviction": cache_eviction, "cache_size": cache_size, "cache_ttl": cache_ttl, } diff --git a/redis/cluster.py b/redis/cluster.py index d74019c644..9d6f5be3b3 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -11,6 +11,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff +from redis.cache import EvictionPolicy from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -505,6 +506,7 @@ def __init__( address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, use_cache: bool = False, cache: Optional[Cache] = None, + cache_eviction: Optional[EvictionPolicy] = None, cache_size: int = 128, cache_ttl: int = 300, **kwargs, @@ -648,6 +650,7 @@ def __init__( address_remap=address_remap, use_cache=use_cache, cache=cache, + cache_eviction=cache_eviction, cache_size=cache_size, cache_ttl=cache_ttl, **kwargs, @@ -1331,6 +1334,7 @@ def __init__( address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, use_cache: bool = False, cache: Optional[Cache] = None, + cache_eviction: Optional[EvictionPolicy] = None, cache_size: int = 128, cache_ttl: int = 300, **kwargs, @@ -1347,6 +1351,7 @@ def __init__( self.address_remap = address_remap self.use_cache = use_cache self.cache = cache + self.cache_eviction = cache_eviction self.cache_size = cache_size self.cache_ttl = cache_ttl self._moved_exception = None @@ -1494,6 +1499,7 @@ def create_redis_node(self, host, port, **kwargs): kwargs.update({"port": port}) kwargs.update({"use_cache": self.use_cache}) kwargs.update({"cache": self.cache}) + kwargs.update({"cache_eviction": self.cache_eviction}) kwargs.update({"cache_size": self.cache_size}) kwargs.update({"cache_ttl": self.cache_ttl}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) @@ -1503,6 +1509,7 @@ def create_redis_node(self, host, port, **kwargs): port=port, use_cache=self.use_cache, cache=self.cache, + cache_eviction=self.cache_eviction, cache_size=self.cache_size, cache_ttl=self.cache_ttl, **kwargs, @@ -1682,9 +1689,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]: return self.address_remap((host, port)) return host, port - def get_cache(self) -> Optional[Cache]: - return self.connection_pool.cache - class ClusterPubSub(PubSub): """ diff --git a/redis/connection.py b/redis/connection.py index d9f61b1da9..5fa76dc50f 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -15,7 +15,7 @@ from apscheduler.schedulers.background import BackgroundScheduler from cachetools import TTLCache, Cache, LRUCache from cachetools.keys import hashkey -from redis.cache import CacheConfiguration +from redis.cache import CacheConfiguration, CacheFactory from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff @@ -1246,6 +1246,7 @@ def __init__( self.max_connections = max_connections self.cache = None self._cache_conf = None + self._cache_factory = None self.cache_lock = None self.scheduler = None @@ -1260,13 +1261,15 @@ def __init__( if cache is not None: self.cache = cache else: - self.cache = TTLCache(self.connection_kwargs["cache_size"], self.connection_kwargs["cache_ttl"]) + cache_factory = CacheFactory(self._cache_conf) + self.cache = cache_factory.get_cache() self.scheduler = BackgroundScheduler() self.scheduler.add_job(self._perform_health_check, "interval", seconds=2, id="cache_health_check") self.scheduler.start() connection_kwargs.pop("use_cache", None) + connection_kwargs.pop("cache_eviction", None) connection_kwargs.pop("cache_size", None) connection_kwargs.pop("cache_ttl", None) connection_kwargs.pop("cache", None) diff --git a/tests/test_cache.py b/tests/test_cache.py index 843d5ae138..bf89971039 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -5,6 +5,7 @@ from cachetools import TTLCache, LRUCache, LFUCache import redis +from redis.cache import EvictionPolicy, CacheClass from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client @@ -13,6 +14,9 @@ def r(request): use_cache = request.param.get("use_cache", False) cache = request.param.get("cache") + cache_eviction = request.param.get("cache_eviction") + cache_size = request.param.get("cache_size") + cache_ttl = request.param.get("cache_ttl") kwargs = request.param.get("kwargs", {}) protocol = request.param.get("protocol", 3) ssl = request.param.get("ssl", False) @@ -25,6 +29,9 @@ def r(request): single_connection_client=single_connection_client, use_cache=use_cache, cache=cache, + cache_eviction=cache_eviction, + cache_size=cache_size, + cache_ttl=cache_ttl, **kwargs, ) as client: yield client @@ -41,7 +48,7 @@ class TestCache: {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster - def test_get_from_cache(self, r, r2): + def test_get_from_given_cache(self, r, r2): cache = r.get_cache() # add key to redis r.set("foo", "bar") @@ -56,6 +63,31 @@ def test_get_from_cache(self, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [ + {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, + {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, + ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + def test_get_from_custom_cache(self, request, r, r2): + cache_class = CacheClass[request.node.callspec.id] + cache = r.get_cache() + assert isinstance(cache, cache_class.value) + + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + + @pytest.mark.parametrize("r", [ {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, @@ -175,8 +207,8 @@ def test_cache_clears_on_disconnect(self, r, cache): assert cache.currsize == 0 @pytest.mark.parametrize("r", [ - {"cache": LRUCache(3), "use_cache": True, "single_connection_client": True}, - {"cache": LRUCache(3), "use_cache": True, "single_connection_client": False}, + {"use_cache": True, "cache_size": 3, "single_connection_client": True}, + {"use_cache": True, "cache_size": 3, "single_connection_client": False}, ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster def test_cache_lru_eviction(self, r, cache): @@ -200,8 +232,8 @@ def test_cache_lru_eviction(self, r, cache): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 1), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 1), "use_cache": True, "single_connection_client": False}, + {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_ttl": 1, "single_connection_client": True}, + {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_ttl": 1, "single_connection_client": False}, ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster def test_cache_ttl(self, r): @@ -218,8 +250,8 @@ def test_cache_ttl(self, r): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize("r", [ - {"cache": LFUCache(3), "use_cache": True, "single_connection_client": True}, - {"cache": LFUCache(3), "use_cache": True, "single_connection_client": False}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 3, "single_connection_client": True}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 3, "single_connection_client": False}, ], ids=["single", "pool"], indirect=True) @pytest.mark.onlynoncluster def test_cache_lfu_eviction(self, r): @@ -329,6 +361,31 @@ def test_get_from_cache(self, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [ + {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, + {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, + ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + @pytest.mark.onlycluster + def test_get_from_custom_cache(self, request, r, r2): + cache_class = CacheClass[request.node.callspec.id] + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() + assert isinstance(cache, cache_class.value) + + # add key to redis + assert r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) @pytest.mark.onlycluster def test_get_from_cache_multithreaded(self, r): @@ -577,6 +634,30 @@ def test_get_from_cache(self, master): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [ + {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, + {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, + ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + def test_get_from_custom_cache(self, request, r, r2): + cache_class = CacheClass[request.node.callspec.id] + cache = r.get_cache() + assert isinstance(cache, cache_class.value) + + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize( "sentinel_setup", [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], @@ -688,6 +769,30 @@ def test_get_from_cache(self, r, r2, cache): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [ + {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300, "ssl": True}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128, "ssl": True}, + {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128, "ssl": True}, + {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128, "ssl": True}, + ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + def test_get_from_custom_cache(self, request, r, r2): + cache_class = CacheClass[request.node.callspec.id] + cache = r.get_cache() + assert isinstance(cache, cache_class.value) + + # add key to redis + r.set("foo", "bar") + # get key from redis and save in local cache + assert r.get("foo") == b"bar" + # get key from local cache + assert cache.get(("GET", "foo")) == b"bar" + # change key in redis (cause invalidation) + r2.set("foo", "barbar") + # Retrieves a new value from server and cache it + assert r.get("foo") == b"barbar" + # Make sure that new value was cached + assert cache.get(("GET", "foo")) == b"barbar" + @pytest.mark.parametrize("r", [ { "cache": TTLCache(128, 300), From 21778be9d9473b109aa32ecd637f878ff12c08d1 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 19 Aug 2024 14:33:47 +0300 Subject: [PATCH 24/78] Added mark to skip test on cluster --- tests/test_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index bf89971039..1c792e1531 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -42,6 +42,7 @@ def set_get(client, key, value): return client.get(key) @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.onlynoncluster class TestCache: @pytest.mark.parametrize("r", [ {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, @@ -69,6 +70,7 @@ def test_get_from_given_cache(self, r, r2): {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + @pytest.mark.onlynoncluster def test_get_from_custom_cache(self, request, r, r2): cache_class = CacheClass[request.node.callspec.id] cache = r.get_cache() @@ -367,7 +369,6 @@ def test_get_from_cache(self, r, r2): {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) - @pytest.mark.onlycluster def test_get_from_custom_cache(self, request, r, r2): cache_class = CacheClass[request.node.callspec.id] cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() From 19c8f35ac5cbfa85833953f35f7d923b0bdee363 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 10:34:25 +0300 Subject: [PATCH 25/78] Removed test case that makes no sense --- redis/connection.py | 1 - tests/test_cache.py | 15 --------------- 2 files changed, 16 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 5fa76dc50f..186d1bb33d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -767,7 +767,6 @@ def check_health(self): self._conn.check_health() def send_packed_command(self, command, check_health=True): - self._process_pending_invalidations() # TODO: Investigate if it's possible to unpack command or extract keys from packed command self._conn.send_packed_command(command) diff --git a/tests/test_cache.py b/tests/test_cache.py index 1c792e1531..08d2ae7bfc 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -130,21 +130,6 @@ def test_get_from_cache_multithreaded(self, r): assert cache.get(("GET", "foo")) == b"baz" assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], indirect=True) - @pytest.mark.onlynoncluster - def test_prevent_race_condition_from_multiple_threads(self, r, cache): - cache = r.get_cache() - - # Set initial key. - assert r.set("foo", "bar") - - # Running concurrent commands over two threads to override same key. - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=set_get, args=(r, "foo", "baz")).start() - assert cache.get(("GET", "foo")) == b"bar" - @pytest.mark.parametrize("r", [ {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, ], indirect=True) From 7edb46bdb458636069b4eac87eeeb0148c03becb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 12:52:28 +0300 Subject: [PATCH 26/78] Skip tests in RESP2 --- tests/test_cache.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 08d2ae7bfc..5783d30c0b 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -7,7 +7,7 @@ import redis from redis.cache import EvictionPolicy, CacheClass from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import _get_client +from tests.conftest import _get_client, skip_if_resp_version @pytest.fixture() @@ -36,13 +36,9 @@ def r(request): ) as client: yield client - -def set_get(client, key, value): - client.set(key, value) - return client.get(key) - @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster +@skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize("r", [ {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, @@ -331,6 +327,7 @@ def test_cache_flushed_on_server_flush(self, r): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster +@skip_if_resp_version(2) class TestClusterCache: @pytest.mark.parametrize("r", [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True) def test_get_from_cache(self, r, r2): @@ -599,6 +596,7 @@ def test_cache_flushed_on_server_flush(self, r, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster +@skip_if_resp_version(2) class TestSentinelCache: @pytest.mark.parametrize( "sentinel_setup", @@ -731,6 +729,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster +@skip_if_resp_version(2) class TestSSLCache: @pytest.mark.parametrize("r", [ { From e6ebab60891a55568e9eb3e85e8f4265a5469121 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 14:20:19 +0300 Subject: [PATCH 27/78] Added scheduler to dev_requirements --- dev_requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/dev_requirements.txt b/dev_requirements.txt index a8da4b49cd..13373a09d1 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,6 @@ black==24.3.0 cachetools +apscheduler click==8.0.4 flake8-isort flake8 From 852c36fa8f61ea7f9c5ce0032248fe5f743ff3ed Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 14:21:40 +0300 Subject: [PATCH 28/78] Codestyle changes --- redis/cache.py | 3 +- redis/client.py | 2 +- redis/cluster.py | 2 +- redis/connection.py | 209 +++++++++--------- redis/retry.py | 32 +-- redis/sentinel.py | 6 +- tests/conftest.py | 2 +- tests/test_cache.py | 487 ++++++++++++++++++++++++++++++++---------- tests/test_cluster.py | 6 +- 9 files changed, 515 insertions(+), 234 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 13b7d867b4..78d8409c44 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -4,7 +4,7 @@ from cachetools import LRUCache, LFUCache, RRCache, Cache, TTLCache -T = TypeVar('T') +T = TypeVar("T") class EvictionPolicy(Enum): @@ -152,4 +152,3 @@ def get_cache(self) -> Cache: def _get_cache_class(self, eviction_policy: EvictionPolicy) -> CacheClass: return CacheClass[eviction_policy.value] - diff --git a/redis/client.py b/redis/client.py index 805fc3d9cd..3bdfa808a4 100755 --- a/redis/client.py +++ b/redis/client.py @@ -149,7 +149,7 @@ class initializer. In the case of conflicting arguments, querystring client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, - use_cache=use_cache + use_cache=use_cache, ) client.auto_close_connection_pool = True return client diff --git a/redis/cluster.py b/redis/cluster.py index 9d6f5be3b3..ace81b6679 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1322,7 +1322,7 @@ def reset(self) -> None: self.primary_to_idx.clear() -class NodesManager(): +class NodesManager: def __init__( self, startup_nodes, diff --git a/redis/connection.py b/redis/connection.py index 186d1bb33d..15f7a8d2e2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -104,9 +104,9 @@ def pack(self, *args): # output list if we're sending large values or memoryviews arg_length = len(arg) if ( - len(buff) > buffer_cutoff - or arg_length > buffer_cutoff - or isinstance(arg, memoryview) + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) ): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) @@ -176,11 +176,11 @@ def can_read(self, timeout=0): @abstractmethod def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): pass @@ -197,28 +197,28 @@ class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" def __init__( - self, - db: int = 0, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - retry_on_timeout: bool = False, - retry_on_error=SENTINEL, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - parser_class=DefaultParser, - socket_read_size: int = 65536, - health_check_interval: int = 0, - client_name: Optional[str] = None, - lib_name: Optional[str] = "redis-py", - lib_version: Optional[str] = get_lib_version(), - username: Optional[str] = None, - retry: Union[Any, None] = None, - redis_connect_func: Optional[Callable[[], None]] = None, - credential_provider: Optional[CredentialProvider] = None, - protocol: Optional[int] = 2, - command_packer: Optional[Callable[[], None]] = None, + self, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error=SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class=DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "redis-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + redis_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer: Optional[Callable[[], None]] = None, ): """ Initialize a new Connection. @@ -393,8 +393,8 @@ def on_connect(self): # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( - self.credential_provider - or UsernamePasswordCredentialProvider(self.username, self.password) + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() @@ -442,8 +442,8 @@ def on_connect(self): self.send_command("HELLO", self.protocol) response = self.read_response() if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") @@ -562,11 +562,11 @@ def can_read(self, timeout=0): raise ConnectionError(f"Error while reading from {host_error}: {e.args}") def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): """Read the response from a previously sent command""" @@ -622,9 +622,9 @@ def pack_commands(self, commands): for chunk in self._command_packer.pack(*cmd): chunklen = len(chunk) if ( - buffer_length > buffer_cutoff - or chunklen > buffer_cutoff - or isinstance(chunk, memoryview) + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) ): if pieces: output.append(SYM_EMPTY.join(pieces)) @@ -649,13 +649,13 @@ class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" def __init__( - self, - host="localhost", - port=6379, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, - **kwargs, + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, ): self.host = host self.port = int(port) @@ -677,7 +677,7 @@ def _connect(self): # socket.connect() err = None for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + self.host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -717,7 +717,7 @@ def _host_error(self): def ensure_string(key): if isinstance(key, bytes): - return key.decode('utf-8') + return key.decode("utf-8") elif isinstance(key, str): return key else: @@ -725,7 +725,13 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): - def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration, cache_lock: threading.Lock): + def __init__( + self, + conn: ConnectionInterface, + cache: Cache, + conf: CacheConfiguration, + cache_lock: threading.Lock, + ): self.pid = os.getpid() self._conn = conn self.retry = self._conn.retry @@ -804,19 +810,21 @@ def send_command(self, *args, **kwargs): def can_read(self, timeout=0): return self._conn.can_read(timeout) - def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False): + def read_response( + self, disable_decoding=False, *, disconnect_on_error=True, push_request=False + ): with self._cache_lock: # Check if command response exists in a cache and it's not in progress. if ( - self._current_command_hash in self._cache - and self._cache[self._current_command_hash] != "caching-in-progress" + self._current_command_hash in self._cache + and self._cache[self._current_command_hash] != "caching-in-progress" ): return copy.deepcopy(self._cache[self._current_command_hash]) response = self._conn.read_response( disable_decoding=disable_decoding, disconnect_on_error=disconnect_on_error, - push_request=push_request + push_request=push_request, ) with self._cache_lock: @@ -857,19 +865,15 @@ def _host_error(self): self._conn._host_error() def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: - conn.send_command('CLIENT', 'TRACKING', 'ON') + conn.send_command("CLIENT", "TRACKING", "ON") conn.read_response() conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) def _process_pending_invalidations(self): - while self.retry.call_with_retry_on_false( - lambda: self.can_read() - ): + while self.retry.call_with_retry_on_false(lambda: self.can_read()): self._conn.read_response(push_request=True) - def _on_invalidation_callback( - self, data: List[Union[str, Optional[List[str]]]] - ): + def _on_invalidation_callback(self, data: List[Union[str, Optional[List[str]]]]): with self._cache_lock: # Flush cache when DB flushed on server-side if data[1] is None: @@ -892,22 +896,22 @@ class SSLConnection(Connection): """ # noqa def __init__( - self, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_ca_path=None, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - **kwargs, + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_ca_path=None, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + ssl_min_version=None, + ssl_ciphers=None, + **kwargs, ): """Constructor @@ -994,9 +998,9 @@ def _wrap_socket_with_ssl(self, sock): password=self.certificate_password, ) if ( - self.ca_certs is not None - or self.ca_path is not None - or self.ca_data is not None + self.ca_certs is not None + or self.ca_path is not None + or self.ca_data is not None ): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data @@ -1112,9 +1116,9 @@ def to_bool(value): def parse_url(url): if not ( - url.startswith("redis://") - or url.startswith("rediss://") - or url.startswith("unix://") + url.startswith("redis://") + or url.startswith("rediss://") + or url.startswith("unix://") ): raise ValueError( "Redis URL must specify one of the following " @@ -1231,12 +1235,12 @@ class initializer. In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, - connection_class=Connection, - max_connections: Optional[int] = None, - **connection_kwargs, + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, ): - max_connections = max_connections or 2 ** 31 + max_connections = max_connections or 2**31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') @@ -1264,7 +1268,12 @@ def __init__( self.cache = cache_factory.get_cache() self.scheduler = BackgroundScheduler() - self.scheduler.add_job(self._perform_health_check, "interval", seconds=2, id="cache_health_check") + self.scheduler.add_job( + self._perform_health_check, + "interval", + seconds=2, + id="cache_health_check", + ) self.scheduler.start() connection_kwargs.pop("use_cache", None) @@ -1415,7 +1424,7 @@ def make_connection(self) -> "ConnectionInterface": self.connection_class(**self.connection_kwargs), self.cache, self._cache_conf, - self._cache_lock + self._cache_lock, ) return self.connection_class(**self.connection_kwargs) @@ -1480,7 +1489,7 @@ def _perform_health_check(self) -> None: with self._lock: while self._available_connections: conn = self._available_connections.pop() - conn.send_command('PING') + conn.send_command("PING") conn.read_response() @@ -1519,12 +1528,12 @@ class BlockingConnectionPool(ConnectionPool): """ def __init__( - self, - max_connections=50, - timeout=20, - connection_class=Connection, - queue_class=LifoQueue, - **connection_kwargs, + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout @@ -1565,7 +1574,7 @@ def make_connection(self): self.connection_class(**self.connection_kwargs), self.cache, self._cache_conf, - self._cache_lock + self._cache_lock, ) else: connection = self.connection_class(**self.connection_kwargs) diff --git a/redis/retry.py b/redis/retry.py index 0f563e344a..7fbc4039e2 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,7 +1,16 @@ import socket import time from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Optional +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Tuple, + Type, + TypeVar, + Optional, +) from redis.exceptions import ConnectionError, TimeoutError @@ -71,12 +80,12 @@ def call_with_retry( sleep(backoff) def call_with_retry_on_false( - self, - do: Callable[[], T], - on_false: Optional[Callable[[], T]] = None, - max_retries: Optional[int] = 3, - timeout: Optional[float] = 0, - exponent: Optional[int] = 2, + self, + do: Callable[[], T], + on_false: Optional[Callable[[], T]] = None, + max_retries: Optional[int] = 3, + timeout: Optional[float] = 0, + exponent: Optional[int] = 2, ) -> bool: """ Execute an operation that returns boolean value with retry @@ -97,14 +106,7 @@ def call_with_retry_on_false( time.sleep(timeout) return self.call_with_retry_on_false( - do, - on_false, - max_retries - 1, - timeout * exponent, - exponent + do, on_false, max_retries - 1, timeout * exponent, exponent ) return False - - - diff --git a/redis/sentinel.py b/redis/sentinel.py index 857e831527..01e210794c 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -307,7 +307,11 @@ def discover_master(self, service_name): self.sentinels[0], ) - ip = self._force_master_ip if self._force_master_ip is not None else state["ip"] + ip = ( + self._force_master_ip + if self._force_master_ip is not None + else state["ip"] + ) return ip, state["port"] error_info = "" diff --git a/tests/conftest.py b/tests/conftest.py index aacb2da5e3..406f13bf67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -332,7 +332,7 @@ def _get_client( kwargs["ssl_keyfile"] = get_ssl_filename("client-key.pem") # When you try to assign "required" as single string, it assigns tuple instead of string. # Probably some reserved keyword, I can't explain how does it work -_- - kwargs["ssl_cert_reqs"] = "require"+"d" + kwargs["ssl_cert_reqs"] = "require" + "d" kwargs["ssl_ca_certs"] = get_ssl_filename("ca-cert.pem") kwargs["port"] = 6666 kwargs["connection_class"] = connection_class diff --git a/tests/test_cache.py b/tests/test_cache.py index 5783d30c0b..77e7d8a579 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -22,28 +22,42 @@ def r(request): ssl = request.param.get("ssl", False) single_connection_client = request.param.get("single_connection_client", False) with _get_client( - redis.Redis, - request, - protocol=protocol, - ssl=ssl, - single_connection_client=single_connection_client, - use_cache=use_cache, - cache=cache, - cache_eviction=cache_eviction, - cache_size=cache_size, - cache_ttl=cache_ttl, - **kwargs, + redis.Redis, + request, + protocol=protocol, + ssl=ssl, + single_connection_client=single_connection_client, + use_cache=use_cache, + cache=cache, + cache_eviction=cache_eviction, + cache_size=cache_size, + cache_ttl=cache_ttl, + **kwargs, ) as client: yield client + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) class TestCache: - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": True, + }, + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_get_from_given_cache(self, r, r2): cache = r.get_cache() @@ -60,12 +74,34 @@ def test_get_from_given_cache(self, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, - {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, - ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "use_cache": True, + "cache_eviction": EvictionPolicy.TTL, + "cache_size": 128, + "cache_ttl": 300, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LRU, + "cache_size": 128, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LFU, + "cache_size": 128, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.RANDOM, + "cache_size": 128, + }, + ], + ids=["TTL", "LRU", "LFU", "RANDOM"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_get_from_custom_cache(self, request, r, r2): cache_class = CacheClass[request.node.callspec.id] @@ -85,11 +121,23 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": True, + }, + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_get_from_cache_multithreaded(self, r): cache = r.get_cache() @@ -126,9 +174,17 @@ def test_get_from_cache_multithreaded(self, r): assert cache.get(("GET", "foo")) == b"baz" assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + indirect=True, + ) @pytest.mark.onlynoncluster def test_health_check_invalidate_cache(self, r, r2): cache = r.get_cache() @@ -145,7 +201,9 @@ def test_health_check_invalidate_cache(self, r, r2): # Make sure that value was invalidated assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + ) @pytest.mark.onlynoncluster def test_health_check_invalidate_cache_multithreaded(self, r, r2): cache = r.get_cache() @@ -171,10 +229,23 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2): assert cache.get(("GET", "foo")) is None assert cache.get(("GET", "bar")) is None - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": True, + }, + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_clears_on_disconnect(self, r, cache): cache = r.get_cache() @@ -185,14 +256,19 @@ def test_cache_clears_on_disconnect(self, r, cache): # get key from local cache assert cache.get(("GET", "foo")) == b"bar" # Force disconnection - r.connection_pool.get_connection('_').disconnect() + r.connection_pool.get_connection("_").disconnect() # Make sure cache is empty assert cache.currsize == 0 - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_size": 3, "single_connection_client": True}, - {"use_cache": True, "cache_size": 3, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + {"use_cache": True, "cache_size": 3, "single_connection_client": True}, + {"use_cache": True, "cache_size": 3, "single_connection_client": False}, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_lru_eviction(self, r, cache): cache = r.get_cache() @@ -214,10 +290,25 @@ def test_cache_lru_eviction(self, r, cache): # the first key is not in the local cache anymore assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_ttl": 1, "single_connection_client": True}, - {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_ttl": 1, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "use_cache": True, + "cache_eviction": EvictionPolicy.TTL, + "cache_ttl": 1, + "single_connection_client": True, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.TTL, + "cache_ttl": 1, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_ttl(self, r): cache = r.get_cache() @@ -232,10 +323,25 @@ def test_cache_ttl(self, r): # the key is not in the local cache anymore assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 3, "single_connection_client": True}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 3, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LFU, + "cache_size": 3, + "single_connection_client": True, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LFU, + "cache_size": 3, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_lfu_eviction(self, r): cache = r.get_cache() @@ -259,10 +365,23 @@ def test_cache_lfu_eviction(self, r): assert cache.get(("GET", "foo")) == b"bar" assert cache.get(("GET", "foo2")) is None - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": True, + }, + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_ignore_not_allowed_command(self, r): cache = r.get_cache() @@ -272,10 +391,23 @@ def test_cache_ignore_not_allowed_command(self, r): assert r.hrandfield("foo") == b"bar" assert cache.get(("HRANDFIELD", "foo")) is None - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": True, + }, + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_invalidate_all_related_responses(self, r): cache = r.get_cache() @@ -299,10 +431,23 @@ def test_cache_invalidate_all_related_responses(self, r): assert cache.get(("MGET", "foo", "bar")) is None assert cache.get(("GET", "foo")) == b"baz" - @pytest.mark.parametrize("r", [ - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": True}, - {"cache": TTLCache(128, 300), "use_cache": True, "single_connection_client": False}, - ], ids=["single", "pool"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": True, + }, + { + "cache": TTLCache(128, 300), + "use_cache": True, + "single_connection_client": False, + }, + ], + ids=["single", "pool"], + indirect=True, + ) @pytest.mark.onlynoncluster def test_cache_flushed_on_server_flush(self, r): cache = r.get_cache() @@ -329,7 +474,9 @@ def test_cache_flushed_on_server_flush(self, r): @pytest.mark.onlycluster @skip_if_resp_version(2) class TestClusterCache: - @pytest.mark.parametrize("r", [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True + ) def test_get_from_cache(self, r, r2): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add key to redis @@ -345,12 +492,34 @@ def test_get_from_cache(self, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, - {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, - ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "use_cache": True, + "cache_eviction": EvictionPolicy.TTL, + "cache_size": 128, + "cache_ttl": 300, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LRU, + "cache_size": 128, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LFU, + "cache_size": 128, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.RANDOM, + "cache_size": 128, + }, + ], + ids=["TTL", "LRU", "LFU", "RANDOM"], + indirect=True, + ) def test_get_from_custom_cache(self, request, r, r2): cache_class = CacheClass[request.node.callspec.id] cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() @@ -369,7 +538,9 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + ) @pytest.mark.onlycluster def test_get_from_cache_multithreaded(self, r): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() @@ -407,7 +578,9 @@ def test_get_from_cache_multithreaded(self, r): assert cache.get(("GET", "foo")) == b"baz" assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + ) @pytest.mark.onlycluster def test_health_check_invalidate_cache(self, r, r2): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() @@ -424,7 +597,9 @@ def test_health_check_invalidate_cache(self, r, r2): # Make sure that value was invalidated assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + ) @pytest.mark.onlycluster def test_health_check_invalidate_cache_multithreaded(self, r, r2): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() @@ -448,7 +623,9 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2): assert cache.get(("GET", "foo")) is None assert cache.get(("GET", "bar")) is None - @pytest.mark.parametrize("r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + ) @pytest.mark.onlycluster def test_cache_clears_on_disconnect(self, r, r2): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() @@ -459,7 +636,9 @@ def test_cache_clears_on_disconnect(self, r, r2): # get key from local cache assert cache.get(("GET", "foo")) == b"bar" # Force disconnection - r.nodes_manager.get_node_from_slot(10).redis_connection.connection_pool.get_connection("_").disconnect() + r.nodes_manager.get_node_from_slot( + 10 + ).redis_connection.connection_pool.get_connection("_").disconnect() # Make sure cache is empty assert cache.currsize == 0 @@ -489,7 +668,9 @@ def test_cache_lru_eviction(self, r): # the first key is not in the local cache anymore assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize("r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True) + @pytest.mark.parametrize( + "r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True + ) @pytest.mark.onlycluster def test_cache_ttl(self, r): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() @@ -594,13 +775,20 @@ def test_cache_flushed_on_server_flush(self, r, cache): assert r.get("foo") is None assert cache.currsize == 0 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) class TestSentinelCache: @pytest.mark.parametrize( "sentinel_setup", - [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + [ + { + "cache": LRUCache(maxsize=128), + "use_cache": True, + "force_master_ip": "localhost", + } + ], indirect=True, ) @pytest.mark.onlynoncluster @@ -618,12 +806,34 @@ def test_get_from_cache(self, master): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128}, - {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128}, - ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "use_cache": True, + "cache_eviction": EvictionPolicy.TTL, + "cache_size": 128, + "cache_ttl": 300, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LRU, + "cache_size": 128, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LFU, + "cache_size": 128, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.RANDOM, + "cache_size": 128, + }, + ], + ids=["TTL", "LRU", "LFU", "RANDOM"], + indirect=True, + ) def test_get_from_custom_cache(self, request, r, r2): cache_class = CacheClass[request.node.callspec.id] cache = r.get_cache() @@ -644,7 +854,13 @@ def test_get_from_custom_cache(self, request, r, r2): @pytest.mark.parametrize( "sentinel_setup", - [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + [ + { + "cache": LRUCache(maxsize=128), + "use_cache": True, + "force_master_ip": "localhost", + } + ], indirect=True, ) @pytest.mark.onlynoncluster @@ -689,7 +905,13 @@ def test_get_from_cache_multithreaded(self, master): @pytest.mark.parametrize( "sentinel_setup", - [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + [ + { + "cache": LRUCache(maxsize=128), + "use_cache": True, + "force_master_ip": "localhost", + } + ], indirect=True, ) @pytest.mark.onlynoncluster @@ -710,7 +932,13 @@ def test_health_check_invalidate_cache(self, master, cache): @pytest.mark.parametrize( "sentinel_setup", - [{"cache": LRUCache(maxsize=128), "use_cache": True, "force_master_ip": "localhost"}], + [ + { + "cache": LRUCache(maxsize=128), + "use_cache": True, + "force_master_ip": "localhost", + } + ], indirect=True, ) @pytest.mark.onlynoncluster @@ -723,21 +951,26 @@ def test_cache_clears_on_disconnect(self, master, cache): # get key from local cache assert cache.get(("GET", "foo")) == b"bar" # Force disconnection - master.connection_pool.get_connection('_').disconnect() + master.connection_pool.get_connection("_").disconnect() # Make sure cache is empty assert cache.currsize == 0 + @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) class TestSSLCache: - @pytest.mark.parametrize("r", [ - { - "cache": TTLCache(128, 300), - "use_cache": True, - "ssl": True, - } - ], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], + indirect=True, + ) @pytest.mark.onlynoncluster def test_get_from_cache(self, r, r2, cache): cache = r.get_cache() @@ -754,12 +987,38 @@ def test_get_from_cache(self, r, r2, cache): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [ - {"use_cache": True, "cache_eviction": EvictionPolicy.TTL, "cache_size": 128, "cache_ttl": 300, "ssl": True}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LRU, "cache_size": 128, "ssl": True}, - {"use_cache": True, "cache_eviction": EvictionPolicy.LFU, "cache_size": 128, "ssl": True}, - {"use_cache": True, "cache_eviction": EvictionPolicy.RANDOM, "cache_size": 128, "ssl": True}, - ], ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "use_cache": True, + "cache_eviction": EvictionPolicy.TTL, + "cache_size": 128, + "cache_ttl": 300, + "ssl": True, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LRU, + "cache_size": 128, + "ssl": True, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.LFU, + "cache_size": 128, + "ssl": True, + }, + { + "use_cache": True, + "cache_eviction": EvictionPolicy.RANDOM, + "cache_size": 128, + "ssl": True, + }, + ], + ids=["TTL", "LRU", "LFU", "RANDOM"], + indirect=True, + ) def test_get_from_custom_cache(self, request, r, r2): cache_class = CacheClass[request.node.callspec.id] cache = r.get_cache() @@ -778,13 +1037,17 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize("r", [ - { - "cache": TTLCache(128, 300), - "use_cache": True, - "ssl": True, - } - ], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], + indirect=True, + ) @pytest.mark.onlynoncluster def test_get_from_cache_multithreaded(self, r): cache = r.get_cache() @@ -821,13 +1084,17 @@ def test_get_from_cache_multithreaded(self, r): assert cache.get(("GET", "foo")) == b"baz" assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize("r", [ - { - "cache": TTLCache(128, 300), - "use_cache": True, - "ssl": True, - } - ], indirect=True) + @pytest.mark.parametrize( + "r", + [ + { + "cache": TTLCache(128, 300), + "use_cache": True, + "ssl": True, + } + ], + indirect=True, + ) @pytest.mark.onlynoncluster def test_health_check_invalidate_cache(self, r, r2): cache = r.get_cache() diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 8d4f06126f..4ad88e7c08 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -643,10 +643,10 @@ def parse_response_mock_third(connection, *args, **options): mocks["send_command"].assert_has_calls( [ call("READONLY"), - call("GET", "foo", keys=['foo']), + call("GET", "foo", keys=["foo"]), call("READONLY"), - call("GET", "foo", keys=['foo']), - call("GET", "foo", keys=['foo']), + call("GET", "foo", keys=["foo"]), + call("GET", "foo", keys=["foo"]), ] ) From 726803a04801a70cb514850047e380a01f4a02e6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 14:30:14 +0300 Subject: [PATCH 29/78] Fixed characters per line restriction --- redis/asyncio/client.py | 1 - redis/connection.py | 25 ++++++++++++++++--------- tests/conftest.py | 6 ++++-- tests/test_cache.py | 9 ++++++--- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index b2ad2e2db8..931ec0effd 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -27,7 +27,6 @@ ) from cachetools import Cache - from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, diff --git a/redis/connection.py b/redis/connection.py index 15f7a8d2e2..b00def811e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,12 +8,12 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from time import time, sleep +from time import time from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse from apscheduler.schedulers.background import BackgroundScheduler -from cachetools import TTLCache, Cache, LRUCache +from cachetools import Cache, LRUCache from cachetools.keys import hashkey from redis.cache import CacheConfiguration, CacheFactory @@ -773,13 +773,15 @@ def check_health(self): self._conn.check_health() def send_packed_command(self, command, check_health=True): - # TODO: Investigate if it's possible to unpack command or extract keys from packed command + # TODO: Investigate if it's possible to unpack command + # or extract keys from packed command self._conn.send_packed_command(command) def send_command(self, *args, **kwargs): self._process_pending_invalidations() - # If command is write command or not allowed to cache, transfer control to the actual connection. + # If command is write command or not allowed to cache + # transfer control to the actual connection. if not self._conf.is_allowed_to_cache(args[0]): self._current_command_hash = None self._current_command_keys = None @@ -797,14 +799,17 @@ def send_command(self, *args, **kwargs): raise TypeError("Cache keys must be a list.") with self._cache_lock: - # If current command reply already cached prevent sending data over socket. + # If current command reply already cached + # prevent sending data over socket. if self._cache.get(self._current_command_hash): return - # Set temporary entry as a status to prevent race condition from another connection. + # Set temporary entry as a status to prevent + # race condition from another connection. self._cache[self._current_command_hash] = "caching-in-progress" - # Send command over socket only if it's allowed read-only command that not yet cached. + # Send command over socket only if it's allowed + # read-only command that not yet cached. self._conn.send_command(*args, **kwargs) def can_read(self, timeout=0): @@ -836,7 +841,8 @@ def read_response( elif self._current_command_hash is None: return response - # Create separate mapping for keys or add current response to associated keys. + # Create separate mapping for keys + # or add current response to associated keys. for key in self._current_command_keys: if key in self._keys_mapping: if self._current_command_hash not in self._keys_mapping[key]: @@ -846,7 +852,8 @@ def read_response( cache_entry = self._cache.get(self._current_command_hash, None) - # Cache only responses that still valid and wasn't invalidated by another connection in meantime. + # Cache only responses that still valid + # and wasn't invalidated by another connection in meantime. if cache_entry is not None: self._cache[self._current_command_hash] = response diff --git a/tests/conftest.py b/tests/conftest.py index 406f13bf67..bcb1543428 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -330,8 +330,10 @@ def _get_client( connection_class = SSLConnection kwargs["ssl_certfile"] = get_ssl_filename("client-cert.pem") kwargs["ssl_keyfile"] = get_ssl_filename("client-key.pem") - # When you try to assign "required" as single string, it assigns tuple instead of string. - # Probably some reserved keyword, I can't explain how does it work -_- + # When you try to assign "required" as single string + # it assigns tuple instead of string. + # Probably some reserved keyword + # I can't explain how does it work -_- kwargs["ssl_cert_reqs"] = "require" + "d" kwargs["ssl_ca_certs"] = get_ssl_filename("ca-cert.pem") kwargs["port"] = 6666 diff --git a/tests/test_cache.py b/tests/test_cache.py index 77e7d8a579..e17eb31919 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -425,7 +425,8 @@ def test_cache_invalidate_all_related_responses(self, r): res.append(b"baz") assert another_res != res - # Invalidate one of the keys and make sure that all associated cached entries was removed + # Invalidate one of the keys and make sure that + # all associated cached entries was removed assert r.set("foo", "baz") assert r.get("foo") == b"baz" assert cache.get(("MGET", "foo", "bar")) is None @@ -743,7 +744,8 @@ def test_cache_invalidate_all_related_responses(self, r, cache): assert r.mget("foo{slot}", "bar{slot}") == [b"bar", b"foo"] assert cache.get(("MGET", "foo{slot}", "bar{slot}")) == [b"bar", b"foo"] - # Invalidate one of the keys and make sure that all associated cached entries was removed + # Invalidate one of the keys and make sure + # that all associated cached entries was removed assert r.set("foo{slot}", "baz") assert r.get("foo{slot}") == b"baz" assert cache.get(("MGET", "foo{slot}", "bar{slot}")) is None @@ -1133,7 +1135,8 @@ def test_cache_invalidate_all_related_responses(self, r): assert r.mget("foo", "bar") == [b"bar", b"foo"] assert cache.get(("MGET", "foo", "bar")) == [b"bar", b"foo"] - # Invalidate one of the keys and make sure that all associated cached entries was removed + # Invalidate one of the keys and make sure + # that all associated cached entries was removed assert r.set("foo", "baz") assert r.get("foo") == b"baz" assert cache.get(("MGET", "foo", "bar")) is None From bed4d73e2d20c44d568636ec411e882e02b1c567 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 14:33:40 +0300 Subject: [PATCH 30/78] Fixed line length --- redis/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index b00def811e..8f2af280c9 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -889,7 +889,8 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[str]]]]) for key in data[1]: normalized_key = ensure_string(key) if normalized_key in self._keys_mapping: - # Make sure that all command responses associated with this key will be deleted + # Make sure that all command responses + # associated with this key will be deleted for cache_key in self._keys_mapping[normalized_key]: self._cache.pop(cache_key) # Removes key from mapping cache From a78690a8e3bac5ff5969d119c71b0723da9485e8 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 14:39:09 +0300 Subject: [PATCH 31/78] Removed blank lines in imports --- redis/cache.py | 4 ---- redis/client.py | 2 -- redis/cluster.py | 2 -- tests/conftest.py | 1 - tests/test_cache.py | 2 -- 5 files changed, 11 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 78d8409c44..1fa7f63eaa 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,11 +1,7 @@ from abc import ABC, abstractmethod -from typing import TypeVar from enum import Enum - from cachetools import LRUCache, LFUCache, RRCache, Cache, TTLCache -T = TypeVar("T") - class EvictionPolicy(Enum): LRU = "LRU" diff --git a/redis/client.py b/redis/client.py index 3bdfa808a4..b37f749392 100755 --- a/redis/client.py +++ b/redis/client.py @@ -5,9 +5,7 @@ import warnings from itertools import chain from typing import Any, Callable, Dict, List, Optional, Type, Union - from cachetools import Cache - from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, diff --git a/redis/cluster.py b/redis/cluster.py index ace81b6679..aef3aadeb2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -5,9 +5,7 @@ import time from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union - from cachetools import Cache - from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff diff --git a/tests/conftest.py b/tests/conftest.py index bcb1543428..6f8345b290 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse - import pytest import redis from packaging.version import Version diff --git a/tests/test_cache.py b/tests/test_cache.py index e17eb31919..ba4ab86821 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,9 +1,7 @@ import threading import time - import pytest from cachetools import TTLCache, LRUCache, LFUCache - import redis from redis.cache import EvictionPolicy, CacheClass from redis.utils import HIREDIS_AVAILABLE From c68524852854b71f32b226a6aa1905bda12a7f59 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 20 Aug 2024 14:42:11 +0300 Subject: [PATCH 32/78] Fixed imports codestyle --- redis/cache.py | 3 ++- redis/client.py | 1 + redis/cluster.py | 1 + redis/retry.py | 2 +- tests/conftest.py | 3 ++- tests/test_cache.py | 5 +++-- 6 files changed, 10 insertions(+), 5 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 1fa7f63eaa..8f51803f1f 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum -from cachetools import LRUCache, LFUCache, RRCache, Cache, TTLCache + +from cachetools import Cache, LFUCache, LRUCache, RRCache, TTLCache class EvictionPolicy(Enum): diff --git a/redis/client.py b/redis/client.py index b37f749392..8626748bff 100755 --- a/redis/client.py +++ b/redis/client.py @@ -5,6 +5,7 @@ import warnings from itertools import chain from typing import Any, Callable, Dict, List, Optional, Type, Union + from cachetools import Cache from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( diff --git a/redis/cluster.py b/redis/cluster.py index aef3aadeb2..4ff28e0104 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -5,6 +5,7 @@ import time from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union + from cachetools import Cache from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan diff --git a/redis/retry.py b/redis/retry.py index 7fbc4039e2..218159f861 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -6,10 +6,10 @@ Any, Callable, Iterable, + Optional, Tuple, Type, TypeVar, - Optional, ) from redis.exceptions import ConnectionError, TimeoutError diff --git a/tests/conftest.py b/tests/conftest.py index 6f8345b290..f2da666b27 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,12 +5,13 @@ from unittest import mock from unittest.mock import Mock from urllib.parse import urlparse + import pytest import redis from packaging.version import Version from redis import Sentinel from redis.backoff import NoBackoff -from redis.connection import Connection, parse_url, SSLConnection +from redis.connection import Connection, SSLConnection, parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_ssl_filename diff --git a/tests/test_cache.py b/tests/test_cache.py index ba4ab86821..c63f5bd2b9 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,9 +1,10 @@ import threading import time + import pytest -from cachetools import TTLCache, LRUCache, LFUCache import redis -from redis.cache import EvictionPolicy, CacheClass +from cachetools import LFUCache, LRUCache, TTLCache +from redis.cache import CacheClass, EvictionPolicy from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version From 65bd5af60ec76d8ba240ff44e519c355118131f7 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 21 Aug 2024 12:27:25 +0300 Subject: [PATCH 33/78] Added CacheInterface abstraction --- redis/cache.py | 98 +++++++++++++++++++++++++++++++++++++++---- redis/client.py | 6 +-- redis/cluster.py | 7 +++- redis/connection.py | 36 ++++++++++------ tests/test_cache.py | 100 +++++++++++++++++++++++++------------------- 5 files changed, 179 insertions(+), 68 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 8f51803f1f..983454b411 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum +from typing import Any, Hashable from cachetools import Cache, LFUCache, LRUCache, RRCache, TTLCache @@ -121,31 +122,112 @@ def is_allowed_to_cache(self, command: str) -> bool: return command in self.DEFAULT_ALLOW_LIST -class CacheClass(Enum): +class EvictionPolicyCacheClass(Enum): LRU = LRUCache LFU = LFUCache RANDOM = RRCache TTL = TTLCache +class CacheClassEvictionPolicy(Enum): + LRUCache = EvictionPolicy.LRU + LFUCache = EvictionPolicy.LFU + RRCache = EvictionPolicy.RANDOM + TTLCache = EvictionPolicy.TTL + + +class CacheInterface(ABC): + + @property + @abstractmethod + def currsize(self) -> float: + pass + + @property + @abstractmethod + def maxsize(self) -> float: + pass + + @property + @abstractmethod + def eviction_policy(self) -> EvictionPolicy: + pass + + @abstractmethod + def get(self, key: Hashable, default: Any = None): + pass + + @abstractmethod + def set(self, key: Hashable, value: Any): + pass + + @abstractmethod + def exists(self, key: Hashable) -> bool: + pass + + @abstractmethod + def remove(self, key: Hashable): + pass + + @abstractmethod + def clear(self): + pass + + class CacheFactoryInterface(ABC): @abstractmethod - def get_cache(self) -> Cache: + def get_cache(self) -> CacheInterface: pass -class CacheFactory(CacheFactoryInterface): +class CacheToolsFactory(CacheFactoryInterface): def __init__(self, conf: CacheConfiguration): self._conf = conf - def get_cache(self) -> Cache: + def get_cache(self) -> CacheInterface: eviction_policy = self._conf.get_eviction_policy() cache_class = self._get_cache_class(eviction_policy).value if eviction_policy == EvictionPolicy.TTL: - return cache_class(self._conf.get_max_size(), self._conf.get_ttl()) + cache_inst = cache_class(self._conf.get_max_size(), self._conf.get_ttl()) + else: + cache_inst = cache_class(self._conf.get_max_size()) + + return CacheToolsAdapter(cache_inst) + + def _get_cache_class( + self, eviction_policy: EvictionPolicy + ) -> EvictionPolicyCacheClass: + return EvictionPolicyCacheClass[eviction_policy.value] + + +class CacheToolsAdapter(CacheInterface): + def __init__(self, cache: Cache): + self._cache = cache + + def get(self, key: Hashable, default: Any = None): + return self._cache.get(key, default) + + def set(self, key: Hashable, value: Any): + self._cache[key] = value + + def exists(self, key: Hashable) -> bool: + return key in self._cache + + def remove(self, key: Hashable): + self._cache.pop(key) + + def clear(self): + self._cache.clear() + + @property + def currsize(self) -> float: + return self._cache.currsize - return cache_class(self._conf.get_max_size()) + @property + def maxsize(self) -> float: + return self._cache.maxsize - def _get_cache_class(self, eviction_policy: EvictionPolicy) -> CacheClass: - return CacheClass[eviction_policy.value] + @property + def eviction_policy(self) -> EvictionPolicy: + return CacheClassEvictionPolicy[self._cache.__class__.__name__].value diff --git a/redis/client.py b/redis/client.py index 8626748bff..c79462a890 100755 --- a/redis/client.py +++ b/redis/client.py @@ -14,7 +14,7 @@ _RedisCallbacksRESP3, bool_ok, ) -from redis.cache import EvictionPolicy +from redis.cache import CacheInterface, EvictionPolicy from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -215,7 +215,7 @@ def __init__( credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, use_cache: bool = False, - cache: Optional[Cache] = None, + cache: Optional[CacheInterface] = None, cache_eviction: Optional[EvictionPolicy] = None, cache_size: int = 128, cache_ttl: int = 300, @@ -603,7 +603,7 @@ def parse_response(self, connection, command_name, **options): return self.response_callbacks[command_name](response, **options) return response - def get_cache(self) -> Optional[Cache]: + def get_cache(self) -> Optional[CacheInterface]: return self.connection_pool.cache diff --git a/redis/cluster.py b/redis/cluster.py index 4ff28e0104..38ad71e243 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -10,7 +10,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.cache import EvictionPolicy +from redis.cache import CacheInterface, EvictionPolicy from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -170,6 +170,9 @@ def parse_cluster_myshardid(resp, **options): "unix_socket_path", "username", "use_cache", + "cache", + "cache_size", + "cache_ttl", ) KWARGS_DISABLED_KEYS = ("host", "port") @@ -504,7 +507,7 @@ def __init__( url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, use_cache: bool = False, - cache: Optional[Cache] = None, + cache: Optional[CacheInterface] = None, cache_eviction: Optional[EvictionPolicy] = None, cache_size: int = 128, cache_ttl: int = 300, diff --git a/redis/connection.py b/redis/connection.py index 8f2af280c9..062fa916a9 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -15,7 +15,12 @@ from apscheduler.schedulers.background import BackgroundScheduler from cachetools import Cache, LRUCache from cachetools.keys import hashkey -from redis.cache import CacheConfiguration, CacheFactory +from redis.cache import ( + CacheConfiguration, + CacheFactoryInterface, + CacheInterface, + CacheToolsFactory, +) from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff @@ -728,7 +733,7 @@ class CacheProxyConnection(ConnectionInterface): def __init__( self, conn: ConnectionInterface, - cache: Cache, + cache: CacheInterface, conf: CacheConfiguration, cache_lock: threading.Lock, ): @@ -806,7 +811,7 @@ def send_command(self, *args, **kwargs): # Set temporary entry as a status to prevent # race condition from another connection. - self._cache[self._current_command_hash] = "caching-in-progress" + self._cache.set(self._current_command_hash, "caching-in-progress") # Send command over socket only if it's allowed # read-only command that not yet cached. @@ -821,10 +826,10 @@ def read_response( with self._cache_lock: # Check if command response exists in a cache and it's not in progress. if ( - self._current_command_hash in self._cache - and self._cache[self._current_command_hash] != "caching-in-progress" + self._cache.exists(self._current_command_hash) + and self._cache.get(self._current_command_hash) != "caching-in-progress" ): - return copy.deepcopy(self._cache[self._current_command_hash]) + return copy.deepcopy(self._cache.get(self._current_command_hash)) response = self._conn.read_response( disable_decoding=disable_decoding, @@ -835,7 +840,7 @@ def read_response( with self._cache_lock: # If response is None prevent from caching. if response is None: - self._cache.pop(self._current_command_hash) + self._cache.remove(self._current_command_hash) return response # Prevent not-allowed command from caching. elif self._current_command_hash is None: @@ -855,7 +860,7 @@ def read_response( # Cache only responses that still valid # and wasn't invalidated by another connection in meantime. if cache_entry is not None: - self._cache[self._current_command_hash] = response + self._cache.set(self._current_command_hash, response) return response @@ -892,7 +897,7 @@ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[str]]]]) # Make sure that all command responses # associated with this key will be deleted for cache_key in self._keys_mapping[normalized_key]: - self._cache.pop(cache_key) + self._cache.remove(cache_key) # Removes key from mapping cache self._keys_mapping.pop(normalized_key) @@ -1246,6 +1251,7 @@ def __init__( self, connection_class=Connection, max_connections: Optional[int] = None, + cache_factory: Optional[CacheFactoryInterface] = None, **connection_kwargs, ): max_connections = max_connections or 2**31 @@ -1257,7 +1263,7 @@ def __init__( self.max_connections = max_connections self.cache = None self._cache_conf = None - self._cache_factory = None + self._cache_factory = cache_factory self.cache_lock = None self.scheduler = None @@ -1269,11 +1275,17 @@ def __init__( self._cache_lock = threading.Lock() cache = self.connection_kwargs.get("cache") + if cache is not None: + if not isinstance(cache, CacheInterface): + raise ValueError("Cache must implement CacheInterface") + self.cache = cache else: - cache_factory = CacheFactory(self._cache_conf) - self.cache = cache_factory.get_cache() + if self._cache_factory is not None: + self.cache = self._cache_factory.get_cache() + else: + self.cache = CacheToolsFactory(self._cache_conf).get_cache() self.scheduler = BackgroundScheduler() self.scheduler.add_job( diff --git a/tests/test_cache.py b/tests/test_cache.py index c63f5bd2b9..5ca3355d45 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -4,7 +4,7 @@ import pytest import redis from cachetools import LFUCache, LRUCache, TTLCache -from redis.cache import CacheClass, EvictionPolicy +from redis.cache import CacheToolsAdapter, EvictionPolicy, EvictionPolicyCacheClass from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version @@ -44,12 +44,12 @@ class TestCache: "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": True, }, { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -103,9 +103,9 @@ def test_get_from_given_cache(self, r, r2): ) @pytest.mark.onlynoncluster def test_get_from_custom_cache(self, request, r, r2): - cache_class = CacheClass[request.node.callspec.id] + expected_policy = EvictionPolicy(request.node.callspec.id) cache = r.get_cache() - assert isinstance(cache, cache_class.value) + assert expected_policy == cache.eviction_policy # add key to redis r.set("foo", "bar") @@ -124,12 +124,12 @@ def test_get_from_custom_cache(self, request, r, r2): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": True, }, { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -177,7 +177,7 @@ def test_get_from_cache_multithreaded(self, r): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -201,7 +201,9 @@ def test_health_check_invalidate_cache(self, r, r2): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize( - "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + indirect=True, ) @pytest.mark.onlynoncluster def test_health_check_invalidate_cache_multithreaded(self, r, r2): @@ -232,12 +234,12 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": True, }, { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -368,12 +370,12 @@ def test_cache_lfu_eviction(self, r): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": True, }, { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -394,12 +396,12 @@ def test_cache_ignore_not_allowed_command(self, r): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": True, }, { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -435,12 +437,12 @@ def test_cache_invalidate_all_related_responses(self, r): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": True, }, { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "single_connection_client": False, }, @@ -475,7 +477,9 @@ def test_cache_flushed_on_server_flush(self, r): @skip_if_resp_version(2) class TestClusterCache: @pytest.mark.parametrize( - "r", [{"cache": LRUCache(maxsize=128), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], + indirect=True, ) def test_get_from_cache(self, r, r2): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() @@ -521,9 +525,9 @@ def test_get_from_cache(self, r, r2): indirect=True, ) def test_get_from_custom_cache(self, request, r, r2): - cache_class = CacheClass[request.node.callspec.id] + expected_policy = EvictionPolicy[request.node.callspec.id] cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() - assert isinstance(cache, cache_class.value) + assert expected_policy == cache.eviction_policy # add key to redis assert r.set("foo", "bar") @@ -539,7 +543,9 @@ def test_get_from_custom_cache(self, request, r, r2): assert cache.get(("GET", "foo")) == b"barbar" @pytest.mark.parametrize( - "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + indirect=True, ) @pytest.mark.onlycluster def test_get_from_cache_multithreaded(self, r): @@ -579,7 +585,9 @@ def test_get_from_cache_multithreaded(self, r): assert cache.get(("GET", "bar")) == b"bar" @pytest.mark.parametrize( - "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + indirect=True, ) @pytest.mark.onlycluster def test_health_check_invalidate_cache(self, r, r2): @@ -598,7 +606,9 @@ def test_health_check_invalidate_cache(self, r, r2): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize( - "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + indirect=True, ) @pytest.mark.onlycluster def test_health_check_invalidate_cache_multithreaded(self, r, r2): @@ -624,7 +634,9 @@ def test_health_check_invalidate_cache_multithreaded(self, r, r2): assert cache.get(("GET", "bar")) is None @pytest.mark.parametrize( - "r", [{"cache": TTLCache(128, 300), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + indirect=True, ) @pytest.mark.onlycluster def test_cache_clears_on_disconnect(self, r, r2): @@ -644,7 +656,7 @@ def test_cache_clears_on_disconnect(self, r, r2): @pytest.mark.parametrize( "r", - [{"cache": LRUCache(3), "use_cache": True}], + [{"cache": CacheToolsAdapter(LRUCache(3)), "use_cache": True}], indirect=True, ) @pytest.mark.onlycluster @@ -669,7 +681,9 @@ def test_cache_lru_eviction(self, r): assert cache.get(("GET", "foo")) is None @pytest.mark.parametrize( - "r", [{"cache": TTLCache(maxsize=128, ttl=1), "use_cache": True}], indirect=True + "r", + [{"cache": CacheToolsAdapter(TTLCache(maxsize=128, ttl=1)), "use_cache": True}], + indirect=True, ) @pytest.mark.onlycluster def test_cache_ttl(self, r): @@ -687,7 +701,7 @@ def test_cache_ttl(self, r): @pytest.mark.parametrize( "r", - [{"cache": LFUCache(3), "use_cache": True}], + [{"cache": CacheToolsAdapter(LFUCache(3)), "use_cache": True}], indirect=True, ) @pytest.mark.onlycluster @@ -715,7 +729,7 @@ def test_cache_lfu_eviction(self, r): @pytest.mark.parametrize( "r", - [{"cache": LRUCache(maxsize=128), "use_cache": True}], + [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], indirect=True, ) @pytest.mark.onlycluster @@ -729,7 +743,7 @@ def test_cache_ignore_not_allowed_command(self, r): @pytest.mark.parametrize( "r", - [{"cache": LRUCache(maxsize=128), "use_cache": True}], + [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], indirect=True, ) @pytest.mark.onlycluster @@ -752,7 +766,7 @@ def test_cache_invalidate_all_related_responses(self, r, cache): @pytest.mark.parametrize( "r", - [{"cache": LRUCache(maxsize=128), "use_cache": True}], + [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], indirect=True, ) @pytest.mark.onlycluster @@ -785,7 +799,7 @@ class TestSentinelCache: "sentinel_setup", [ { - "cache": LRUCache(maxsize=128), + "cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True, "force_master_ip": "localhost", } @@ -836,9 +850,9 @@ def test_get_from_cache(self, master): indirect=True, ) def test_get_from_custom_cache(self, request, r, r2): - cache_class = CacheClass[request.node.callspec.id] + expected_policy = EvictionPolicy[request.node.callspec.id] cache = r.get_cache() - assert isinstance(cache, cache_class.value) + assert expected_policy == cache.eviction_policy # add key to redis r.set("foo", "bar") @@ -857,7 +871,7 @@ def test_get_from_custom_cache(self, request, r, r2): "sentinel_setup", [ { - "cache": LRUCache(maxsize=128), + "cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True, "force_master_ip": "localhost", } @@ -908,7 +922,7 @@ def test_get_from_cache_multithreaded(self, master): "sentinel_setup", [ { - "cache": LRUCache(maxsize=128), + "cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True, "force_master_ip": "localhost", } @@ -935,7 +949,7 @@ def test_health_check_invalidate_cache(self, master, cache): "sentinel_setup", [ { - "cache": LRUCache(maxsize=128), + "cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True, "force_master_ip": "localhost", } @@ -959,13 +973,13 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +# @skip_if_resp_version(2) class TestSSLCache: @pytest.mark.parametrize( "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "ssl": True, } @@ -1021,9 +1035,9 @@ def test_get_from_cache(self, r, r2, cache): indirect=True, ) def test_get_from_custom_cache(self, request, r, r2): - cache_class = CacheClass[request.node.callspec.id] + expected_policy = EvictionPolicy[request.node.callspec.id] cache = r.get_cache() - assert isinstance(cache, cache_class.value) + assert expected_policy == cache.eviction_policy # add key to redis r.set("foo", "bar") @@ -1042,7 +1056,7 @@ def test_get_from_custom_cache(self, request, r, r2): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "ssl": True, } @@ -1089,7 +1103,7 @@ def test_get_from_cache_multithreaded(self, r): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "ssl": True, } @@ -1116,7 +1130,7 @@ def test_health_check_invalidate_cache(self, r, r2): "r", [ { - "cache": TTLCache(128, 300), + "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, "ssl": True, } From 9e9b68c6f08f9a58ac599cd1c871b4889b4fbdf3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 21 Aug 2024 12:31:54 +0300 Subject: [PATCH 34/78] Removed redundant references --- redis/asyncio/client.py | 4 ---- redis/client.py | 1 - redis/cluster.py | 3 +-- redis/connection.py | 2 +- tests/test_cache.py | 4 ++-- 5 files changed, 4 insertions(+), 10 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 931ec0effd..696431d4c8 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -26,7 +26,6 @@ cast, ) -from cachetools import Cache from redis._parsers.helpers import ( _RedisCallbacks, _RedisCallbacksRESP2, @@ -652,9 +651,6 @@ async def parse_response( return await retval if inspect.isawaitable(retval) else retval return response - def get_cache(self) -> Optional[Cache]: - return self.connection_pool.cache - StrictRedis = Redis diff --git a/redis/client.py b/redis/client.py index c79462a890..c058374c67 100755 --- a/redis/client.py +++ b/redis/client.py @@ -6,7 +6,6 @@ from itertools import chain from typing import Any, Callable, Dict, List, Optional, Type, Union -from cachetools import Cache from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( _RedisCallbacks, diff --git a/redis/cluster.py b/redis/cluster.py index 38ad71e243..b3e79efa3f 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,7 +6,6 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from cachetools import Cache from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff @@ -1335,7 +1334,7 @@ def __init__( connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, use_cache: bool = False, - cache: Optional[Cache] = None, + cache: Optional[CacheInterface] = None, cache_eviction: Optional[EvictionPolicy] = None, cache_size: int = 128, cache_ttl: int = 300, diff --git a/redis/connection.py b/redis/connection.py index 062fa916a9..2bbf34175f 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -13,7 +13,7 @@ from urllib.parse import parse_qs, unquote, urlparse from apscheduler.schedulers.background import BackgroundScheduler -from cachetools import Cache, LRUCache +from cachetools import LRUCache from cachetools.keys import hashkey from redis.cache import ( CacheConfiguration, diff --git a/tests/test_cache.py b/tests/test_cache.py index 5ca3355d45..ef45bee018 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -4,7 +4,7 @@ import pytest import redis from cachetools import LFUCache, LRUCache, TTLCache -from redis.cache import CacheToolsAdapter, EvictionPolicy, EvictionPolicyCacheClass +from redis.cache import CacheToolsAdapter, EvictionPolicy from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version @@ -973,7 +973,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -# @skip_if_resp_version(2) +@skip_if_resp_version(2) class TestSSLCache: @pytest.mark.parametrize( "r", From ba942154c526e503226f4765d51cbc1d55d0206b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 21 Aug 2024 16:34:10 +0300 Subject: [PATCH 35/78] Moved hardcoded values to constants, restricted dependency versions --- dev_requirements.txt | 2 +- redis/connection.py | 20 ++++++++++---------- requirements.txt | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 13373a09d1..e90f5a67b2 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,5 @@ black==24.3.0 -cachetools +cachetools>=5.5.0 apscheduler click==8.0.4 flake8-isort diff --git a/redis/connection.py b/redis/connection.py index 2bbf34175f..af854cf391 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -730,12 +730,14 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): + CACHE_DUMMY_STATUS = "caching-in-progress" + KEYS_MAPPING_CACHE_SIZE = 10000 + def __init__( self, conn: ConnectionInterface, cache: CacheInterface, - conf: CacheConfiguration, - cache_lock: threading.Lock, + conf: CacheConfiguration ): self.pid = os.getpid() self._conn = conn @@ -743,12 +745,12 @@ def __init__( self.host = self._conn.host self.port = self._conn.port self._cache = cache - self._cache_lock = cache_lock + self._cache_lock = threading.Lock() self._conf = conf self._current_command_hash = None self._current_command_keys = None self._current_options = None - self._keys_mapping = LRUCache(maxsize=10000) + self._keys_mapping = LRUCache(maxsize=self.KEYS_MAPPING_CACHE_SIZE) self.register_connect_callback(self._enable_tracking_callback) def repr_pieces(self): @@ -772,6 +774,7 @@ def on_connect(self): def disconnect(self, *args): with self._cache_lock: self._cache.clear() + self._keys_mapping.clear() self._conn.disconnect(*args) def check_health(self): @@ -811,7 +814,7 @@ def send_command(self, *args, **kwargs): # Set temporary entry as a status to prevent # race condition from another connection. - self._cache.set(self._current_command_hash, "caching-in-progress") + self._cache.set(self._current_command_hash, self.CACHE_DUMMY_STATUS) # Send command over socket only if it's allowed # read-only command that not yet cached. @@ -827,7 +830,7 @@ def read_response( # Check if command response exists in a cache and it's not in progress. if ( self._cache.exists(self._current_command_hash) - and self._cache.get(self._current_command_hash) != "caching-in-progress" + and self._cache.get(self._current_command_hash) != self.CACHE_DUMMY_STATUS ): return copy.deepcopy(self._cache.get(self._current_command_hash)) @@ -1264,7 +1267,6 @@ def __init__( self.cache = None self._cache_conf = None self._cache_factory = cache_factory - self.cache_lock = None self.scheduler = None if connection_kwargs.get("use_cache"): @@ -1272,7 +1274,6 @@ def __init__( raise RedisError("Client caching is only supported with RESP version 3") self._cache_conf = CacheConfiguration(**self.connection_kwargs) - self._cache_lock = threading.Lock() cache = self.connection_kwargs.get("cache") @@ -1443,8 +1444,7 @@ def make_connection(self) -> "ConnectionInterface": return CacheProxyConnection( self.connection_class(**self.connection_kwargs), self.cache, - self._cache_conf, - self._cache_lock, + self._cache_conf ) return self.connection_class(**self.connection_kwargs) diff --git a/requirements.txt b/requirements.txt index 98c67e1a42..415bbb6ca9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ async-timeout>=4.0.3 -cachetools +cachetools>=5.5.0 apscheduler From 25721676e4a2f32fb1b2fac7a805e951b8ea3e0e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 21 Aug 2024 16:39:51 +0300 Subject: [PATCH 36/78] Changed defaults to correct values --- redis/client.py | 4 ++-- redis/cluster.py | 8 ++++---- redis/connection.py | 10 ++++------ tests/test_cache.py | 2 +- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/redis/client.py b/redis/client.py index c058374c67..a9e5ec6bd9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -216,8 +216,8 @@ def __init__( use_cache: bool = False, cache: Optional[CacheInterface] = None, cache_eviction: Optional[EvictionPolicy] = None, - cache_size: int = 128, - cache_ttl: int = 300, + cache_size: int = 10000, + cache_ttl: int = 0, ) -> None: """ Initialize a new Redis client. diff --git a/redis/cluster.py b/redis/cluster.py index b3e79efa3f..3058d36a7b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -508,8 +508,8 @@ def __init__( use_cache: bool = False, cache: Optional[CacheInterface] = None, cache_eviction: Optional[EvictionPolicy] = None, - cache_size: int = 128, - cache_ttl: int = 300, + cache_size: int = 10000, + cache_ttl: int = 0, **kwargs, ): """ @@ -1336,8 +1336,8 @@ def __init__( use_cache: bool = False, cache: Optional[CacheInterface] = None, cache_eviction: Optional[EvictionPolicy] = None, - cache_size: int = 128, - cache_ttl: int = 300, + cache_size: int = 10000, + cache_ttl: int = 0, **kwargs, ): self.nodes_cache = {} diff --git a/redis/connection.py b/redis/connection.py index af854cf391..a7aff77c13 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -734,10 +734,7 @@ class CacheProxyConnection(ConnectionInterface): KEYS_MAPPING_CACHE_SIZE = 10000 def __init__( - self, - conn: ConnectionInterface, - cache: CacheInterface, - conf: CacheConfiguration + self, conn: ConnectionInterface, cache: CacheInterface, conf: CacheConfiguration ): self.pid = os.getpid() self._conn = conn @@ -830,7 +827,8 @@ def read_response( # Check if command response exists in a cache and it's not in progress. if ( self._cache.exists(self._current_command_hash) - and self._cache.get(self._current_command_hash) != self.CACHE_DUMMY_STATUS + and self._cache.get(self._current_command_hash) + != self.CACHE_DUMMY_STATUS ): return copy.deepcopy(self._cache.get(self._current_command_hash)) @@ -1444,7 +1442,7 @@ def make_connection(self) -> "ConnectionInterface": return CacheProxyConnection( self.connection_class(**self.connection_kwargs), self.cache, - self._cache_conf + self._cache_conf, ) return self.connection_class(**self.connection_kwargs) diff --git a/tests/test_cache.py b/tests/test_cache.py index ef45bee018..6af52f4f7d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -38,7 +38,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +# @skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", From 74499a9adad2c3c3c5c7a329badf816e2a4d0fb4 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 23 Aug 2024 17:46:42 +0300 Subject: [PATCH 37/78] Added custom background scheduler, added unit testing --- dev_requirements.txt | 1 - redis/connection.py | 39 +++++++++----- redis/scheduler.py | 61 ++++++++++++++++++++++ requirements.txt | 3 +- tests/conftest.py | 33 +++++++++++- tests/test_cache.py | 29 +++++++++-- tests/test_connection.py | 107 ++++++++++++++++++++++++++++++++++++++- tests/test_scheduler.py | 38 ++++++++++++++ 8 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 redis/scheduler.py create mode 100644 tests/test_scheduler.py diff --git a/dev_requirements.txt b/dev_requirements.txt index e90f5a67b2..9741dc0555 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,6 +1,5 @@ black==24.3.0 cachetools>=5.5.0 -apscheduler click==8.0.4 flake8-isort flake8 diff --git a/redis/connection.py b/redis/connection.py index a7aff77c13..8cfe9a046c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,11 +8,10 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from time import time +from time import time, sleep from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from apscheduler.schedulers.background import BackgroundScheduler from cachetools import LRUCache from cachetools.keys import hashkey from redis.cache import ( @@ -21,6 +20,7 @@ CacheInterface, CacheToolsFactory, ) +from . import scheduler from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff @@ -36,6 +36,7 @@ TimeoutError, ) from .retry import Retry +from .scheduler import Scheduler from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, @@ -1265,7 +1266,9 @@ def __init__( self.cache = None self._cache_conf = None self._cache_factory = cache_factory - self.scheduler = None + self._scheduler = None + self._hc_cancel_event = None + self._hc_thread = None if connection_kwargs.get("use_cache"): if connection_kwargs.get("protocol") not in [3, "3"]: @@ -1286,14 +1289,7 @@ def __init__( else: self.cache = CacheToolsFactory(self._cache_conf).get_cache() - self.scheduler = BackgroundScheduler() - self.scheduler.add_job( - self._perform_health_check, - "interval", - seconds=2, - id="cache_health_check", - ) - self.scheduler.start() + self._scheduler = Scheduler() connection_kwargs.pop("use_cache", None) connection_kwargs.pop("cache_eviction", None) @@ -1312,6 +1308,16 @@ def __init__( self._fork_lock = threading.Lock() self.reset() + # Run scheduled healthcheck to avoid stale invalidations in idle connections. + if self.cache is not None and self._scheduler is not None: + self._hc_cancel_event = threading.Event() + self._hc_thread = self._scheduler.run_with_interval( + self._perform_health_check, + 2, + self._hc_cancel_event + ) + + def __repr__(self) -> (str, str): return ( f"<{type(self).__module__}.{type(self).__name__}" @@ -1491,6 +1497,14 @@ def disconnect(self, inuse_connections: bool = True) -> None: for connection in connections: connection.disconnect() + # Send an event to stop scheduled healthcheck execution. + if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set(): + self._hc_cancel_event.set() + + # Joins healthcheck thread on disconnect. + if self._hc_thread is not None and not self._hc_thread.is_alive(): + self._hc_thread.join() + def close(self) -> None: """Close the pool, disconnecting all connections""" self.disconnect() @@ -1502,13 +1516,14 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def _perform_health_check(self) -> None: + def _perform_health_check(self, done: threading.Event) -> None: self._checkpid() with self._lock: while self._available_connections: conn = self._available_connections.pop() conn.send_command("PING") conn.read_response() + done.set() class BlockingConnectionPool(ConnectionPool): diff --git a/redis/scheduler.py b/redis/scheduler.py new file mode 100644 index 0000000000..981bf48452 --- /dev/null +++ b/redis/scheduler.py @@ -0,0 +1,61 @@ +import threading +import time +from typing import Callable + + +class Scheduler: + + def __init__(self, polling_period: float = 0.1): + """ + :param polling_period: Period between polling operations. + Needs to detect when new job has to be scheduled. + """ + self.polling_period = polling_period + + def run_with_interval( + self, + func: Callable[[threading.Event, ...], None], + interval: float, + cancel: threading.Event, + args: tuple = (), + ) -> threading.Thread: + """ + Run scheduled execution with given interval + in a separate thread until cancel event won't be set. + """ + done = threading.Event() + thread = threading.Thread(target=self._run_timer, args=(func, interval, (done, *args), done, cancel)) + thread.start() + return thread + + def _get_timer( + self, + func: Callable[[threading.Event, ...], None], + interval: float, + args: tuple + ) -> threading.Timer: + timer = threading.Timer(interval=interval, function=func, args=args) + return timer + + def _run_timer( + self, + func: Callable[[threading.Event, ...], None], + interval: float, + args: tuple, + done: threading.Event, + cancel: threading.Event + ): + timer = self._get_timer(func, interval, args) + timer.start() + + while not cancel.is_set(): + if done.is_set(): + done.clear() + timer.join() + timer = self._get_timer(func, interval, args) + timer.start() + else: + time.sleep(self.polling_period) + + timer.cancel() + timer.join() diff --git a/requirements.txt b/requirements.txt index 415bbb6ca9..7ddaf28d97 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ async-timeout>=4.0.3 -cachetools>=5.5.0 -apscheduler +cachetools>=5.5.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f2da666b27..95b60d0d14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import argparse import random +import threading import time from typing import Callable, TypeVar from unittest import mock @@ -7,11 +8,14 @@ from urllib.parse import urlparse import pytest +from _pytest import unittest + import redis from packaging.version import Version from redis import Sentinel from redis.backoff import NoBackoff -from redis.connection import Connection, SSLConnection, parse_url +from redis.cache import CacheConfiguration, EvictionPolicy, CacheFactoryInterface, CacheInterface +from redis.connection import Connection, SSLConnection, parse_url, ConnectionPool, ConnectionInterface from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_ssl_filename @@ -537,6 +541,33 @@ def master_host(request): return parts.hostname, (parts.port or 6379) +@pytest.fixture() +def cache_conf() -> CacheConfiguration: + return CacheConfiguration( + cache_size=100, + cache_ttl=20, + cache_eviction=EvictionPolicy.TTL + ) + + +@pytest.fixture() +def mock_cache_factory() -> CacheFactoryInterface: + mock_factory = Mock(spec=CacheFactoryInterface) + return mock_factory + + +@pytest.fixture() +def mock_cache() -> CacheInterface: + mock_cache = Mock(spec=CacheInterface) + return mock_cache + + +@pytest.fixture() +def mock_connection() -> ConnectionInterface: + mock_connection = Mock(spec=ConnectionInterface) + return mock_connection + + def wait_for_command(client, monitor, command, key=None): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting diff --git a/tests/test_cache.py b/tests/test_cache.py index 6af52f4f7d..6b24e7ad40 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -4,7 +4,7 @@ import pytest import redis from cachetools import LFUCache, LRUCache, TTLCache -from redis.cache import CacheToolsAdapter, EvictionPolicy +from redis.cache import CacheToolsAdapter, EvictionPolicy, CacheConfiguration from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version @@ -185,7 +185,7 @@ def test_get_from_cache_multithreaded(self, r): indirect=True, ) @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache(self, r, r2): + def test_health_check_invalidate_cache(self, r): cache = r.get_cache() # add key to redis r.set("foo", "bar") @@ -194,7 +194,7 @@ def test_health_check_invalidate_cache(self, r, r2): # get key from local cache assert cache.get(("GET", "foo")) == b"bar" # change key in redis (cause invalidation) - r2.set("foo", "barbar") + r.set("foo", "barbar") # Wait for health check time.sleep(2) # Make sure that value was invalidated @@ -1154,3 +1154,26 @@ def test_cache_invalidate_all_related_responses(self, r): assert r.get("foo") == b"baz" assert cache.get(("MGET", "foo", "bar")) is None assert cache.get(("GET", "foo")) == b"baz" + + +class TestUnitCacheConfiguration: + TTL = 20 + MAX_SIZE = 100 + EVICTION_POLICY = EvictionPolicy.TTL + + def test_get_ttl(self, cache_conf: CacheConfiguration): + assert self.TTL == cache_conf.get_ttl() + + def test_get_max_size(self, cache_conf: CacheConfiguration): + assert self.MAX_SIZE == cache_conf.get_max_size() + + def test_get_eviction_policy(self, cache_conf: CacheConfiguration): + assert self.EVICTION_POLICY == cache_conf.get_eviction_policy() + + def test_is_exceeds_max_size(self, cache_conf: CacheConfiguration): + assert not cache_conf.is_exceeds_max_size(self.MAX_SIZE) + assert cache_conf.is_exceeds_max_size(self.MAX_SIZE + 1) + + def test_is_allowed_to_cache(self, cache_conf: CacheConfiguration): + assert cache_conf.is_allowed_to_cache("GET") + assert not cache_conf.is_allowed_to_cache("SET") \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 69275d58c0..f3269869f5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,17 +4,20 @@ from unittest.mock import patch import pytest +from cachetools import TTLCache, LRUCache + import redis from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff +from redis.cache import EvictionPolicy, CacheInterface, CacheToolsAdapter from redis.connection import ( Connection, SSLConnection, UnixDomainSocketConnection, - parse_url, + parse_url, CacheProxyConnection, ) -from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError, RedisError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -346,3 +349,103 @@ def test_unix_socket_connection_failure(): str(e.value) == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." ) + + +class TestUnitConnectionPool: + + @pytest.mark.parametrize("max_conn", (-1, 'str'), ids=("non-positive", "wrong type")) + def test_throws_error_on_incorrect_max_connections(self, max_conn): + with pytest.raises( + ValueError, + match='"max_connections" must be a positive integer' + ): + ConnectionPool( + max_connections=max_conn, + ) + + def test_throws_error_on_cache_enable_in_resp2(self): + with pytest.raises( + RedisError, + match="Client caching is only supported with RESP version 3" + ): + ConnectionPool( + protocol=2, + use_cache=True + ) + + def test_throws_error_on_incorrect_cache_implementation(self): + with pytest.raises( + ValueError, + match="Cache must implement CacheInterface" + ): + ConnectionPool( + protocol=3, + use_cache=True, + cache=TTLCache(100, 20) + ) + + def test_returns_custom_cache_implementation(self, mock_cache): + connection_pool = ConnectionPool( + protocol=3, + use_cache=True, + cache=mock_cache + ) + + assert mock_cache == connection_pool.cache + connection_pool.disconnect() + + def test_creates_cache_with_custom_cache_factory(self, mock_cache_factory, mock_cache): + mock_cache_factory.get_cache.return_value = mock_cache + + connection_pool = ConnectionPool( + protocol=3, + use_cache=True, + cache_size=100, + cache_ttl=20, + cache_eviction=EvictionPolicy.TTL, + cache_factory=mock_cache_factory + ) + + assert connection_pool.cache == mock_cache + connection_pool.disconnect() + + def test_creates_cache_with_given_configuration(self, mock_cache): + connection_pool = ConnectionPool( + protocol=3, + use_cache=True, + cache_size=100, + cache_ttl=20, + cache_eviction=EvictionPolicy.TTL + ) + + assert isinstance(connection_pool.cache, CacheInterface) + assert connection_pool.cache.maxsize == 100 + assert connection_pool.cache.eviction_policy == EvictionPolicy.TTL + connection_pool.disconnect() + + def test_make_connection_proxy_connection_on_given_cache(self): + connection_pool = ConnectionPool( + protocol=3, + use_cache=True + ) + + assert isinstance(connection_pool.make_connection(), CacheProxyConnection) + connection_pool.disconnect() + + +class TestUnitCacheProxyConnection: + def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): + cache = LRUCache(100) + cache['key'] = 'value' + assert cache['key'] == 'value' + + mock_connection.disconnect.return_value = None + mock_connection.retry = 'mock' + mock_connection.host = 'mock' + mock_connection.port = 'mock' + + proxy_connection = CacheProxyConnection(mock_connection, CacheToolsAdapter(cache), cache_conf) + proxy_connection.disconnect() + + assert cache.currsize == 0 + diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000000..465e27f458 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,38 @@ +import threading +import time + +import pytest + +from redis.scheduler import Scheduler + + +class TestScheduler: + @pytest.mark.parametrize( + "polling_period,interval,expected_count", + [ + (0.001, 0.1, (8, 9)), + (0.1, 0.2, (3, 4)), + (0.1, 2, (0, 0)), + ], + ids=[ + 'small polling period (0.001s)', + 'large polling period (0.1s)', + 'interval larger than timeout - no execution', + ] + ) + def test_run_with_interval(self, polling_period, interval, expected_count): + scheduler = Scheduler(polling_period=polling_period) + cancel_event = threading.Event() + counter = 0 + + def callback(done: threading.Event): + nonlocal counter + counter += 1 + done.set() + + scheduler.run_with_interval(func=callback, interval=interval, cancel=cancel_event) + time.sleep(1) + cancel_event.set() + cancel_event.wait() + # Due to flacky nature of test case, provides at least 2 possible values. + assert counter in expected_count From da32b9e637e413c5b942c7ba5271d4c06c577c3e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 23 Aug 2024 17:47:42 +0300 Subject: [PATCH 38/78] Codestyle changes --- redis/connection.py | 9 ++---- redis/scheduler.py | 31 +++++++++--------- tests/conftest.py | 22 +++++++++---- tests/test_cache.py | 4 +-- tests/test_connection.py | 70 ++++++++++++++++------------------------ tests/test_scheduler.py | 13 ++++---- 6 files changed, 70 insertions(+), 79 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 8cfe9a046c..a6129c62f9 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,7 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from time import time, sleep +from time import sleep, time from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse @@ -20,8 +20,8 @@ CacheInterface, CacheToolsFactory, ) -from . import scheduler +from . import scheduler from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -1312,12 +1312,9 @@ def __init__( if self.cache is not None and self._scheduler is not None: self._hc_cancel_event = threading.Event() self._hc_thread = self._scheduler.run_with_interval( - self._perform_health_check, - 2, - self._hc_cancel_event + self._perform_health_check, 2, self._hc_cancel_event ) - def __repr__(self) -> (str, str): return ( f"<{type(self).__module__}.{type(self).__name__}" diff --git a/redis/scheduler.py b/redis/scheduler.py index 981bf48452..a6a0e7ff12 100644 --- a/redis/scheduler.py +++ b/redis/scheduler.py @@ -13,37 +13,36 @@ def __init__(self, polling_period: float = 0.1): self.polling_period = polling_period def run_with_interval( - self, - func: Callable[[threading.Event, ...], None], - interval: float, - cancel: threading.Event, - args: tuple = (), + self, + func: Callable[[threading.Event, ...], None], + interval: float, + cancel: threading.Event, + args: tuple = (), ) -> threading.Thread: """ Run scheduled execution with given interval in a separate thread until cancel event won't be set. """ done = threading.Event() - thread = threading.Thread(target=self._run_timer, args=(func, interval, (done, *args), done, cancel)) + thread = threading.Thread( + target=self._run_timer, args=(func, interval, (done, *args), done, cancel) + ) thread.start() return thread def _get_timer( - self, - func: Callable[[threading.Event, ...], None], - interval: float, - args: tuple + self, func: Callable[[threading.Event, ...], None], interval: float, args: tuple ) -> threading.Timer: timer = threading.Timer(interval=interval, function=func, args=args) return timer def _run_timer( - self, - func: Callable[[threading.Event, ...], None], - interval: float, - args: tuple, - done: threading.Event, - cancel: threading.Event + self, + func: Callable[[threading.Event, ...], None], + interval: float, + args: tuple, + done: threading.Event, + cancel: threading.Event, ): timer = self._get_timer(func, interval, args) timer.start() diff --git a/tests/conftest.py b/tests/conftest.py index 95b60d0d14..ce36893155 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,14 +8,24 @@ from urllib.parse import urlparse import pytest -from _pytest import unittest - import redis +from _pytest import unittest from packaging.version import Version from redis import Sentinel from redis.backoff import NoBackoff -from redis.cache import CacheConfiguration, EvictionPolicy, CacheFactoryInterface, CacheInterface -from redis.connection import Connection, SSLConnection, parse_url, ConnectionPool, ConnectionInterface +from redis.cache import ( + CacheConfiguration, + CacheFactoryInterface, + CacheInterface, + EvictionPolicy, +) +from redis.connection import ( + Connection, + ConnectionInterface, + ConnectionPool, + SSLConnection, + parse_url, +) from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_ssl_filename @@ -544,9 +554,7 @@ def master_host(request): @pytest.fixture() def cache_conf() -> CacheConfiguration: return CacheConfiguration( - cache_size=100, - cache_ttl=20, - cache_eviction=EvictionPolicy.TTL + cache_size=100, cache_ttl=20, cache_eviction=EvictionPolicy.TTL ) diff --git a/tests/test_cache.py b/tests/test_cache.py index 6b24e7ad40..0323be0a82 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -4,7 +4,7 @@ import pytest import redis from cachetools import LFUCache, LRUCache, TTLCache -from redis.cache import CacheToolsAdapter, EvictionPolicy, CacheConfiguration +from redis.cache import CacheConfiguration, CacheToolsAdapter, EvictionPolicy from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version @@ -1176,4 +1176,4 @@ def test_is_exceeds_max_size(self, cache_conf: CacheConfiguration): def test_is_allowed_to_cache(self, cache_conf: CacheConfiguration): assert cache_conf.is_allowed_to_cache("GET") - assert not cache_conf.is_allowed_to_cache("SET") \ No newline at end of file + assert not cache_conf.is_allowed_to_cache("SET") diff --git a/tests/test_connection.py b/tests/test_connection.py index f3269869f5..e34e431cb0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,20 +4,20 @@ from unittest.mock import patch import pytest -from cachetools import TTLCache, LRUCache - import redis +from cachetools import LRUCache, TTLCache from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff -from redis.cache import EvictionPolicy, CacheInterface, CacheToolsAdapter +from redis.cache import CacheInterface, CacheToolsAdapter, EvictionPolicy from redis.connection import ( + CacheProxyConnection, Connection, SSLConnection, UnixDomainSocketConnection, - parse_url, CacheProxyConnection, + parse_url, ) -from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError, RedisError +from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -353,11 +353,12 @@ def test_unix_socket_connection_failure(): class TestUnitConnectionPool: - @pytest.mark.parametrize("max_conn", (-1, 'str'), ids=("non-positive", "wrong type")) + @pytest.mark.parametrize( + "max_conn", (-1, "str"), ids=("non-positive", "wrong type") + ) def test_throws_error_on_incorrect_max_connections(self, max_conn): with pytest.raises( - ValueError, - match='"max_connections" must be a positive integer' + ValueError, match='"max_connections" must be a positive integer' ): ConnectionPool( max_connections=max_conn, @@ -365,36 +366,23 @@ def test_throws_error_on_incorrect_max_connections(self, max_conn): def test_throws_error_on_cache_enable_in_resp2(self): with pytest.raises( - RedisError, - match="Client caching is only supported with RESP version 3" + RedisError, match="Client caching is only supported with RESP version 3" ): - ConnectionPool( - protocol=2, - use_cache=True - ) + ConnectionPool(protocol=2, use_cache=True) def test_throws_error_on_incorrect_cache_implementation(self): - with pytest.raises( - ValueError, - match="Cache must implement CacheInterface" - ): - ConnectionPool( - protocol=3, - use_cache=True, - cache=TTLCache(100, 20) - ) + with pytest.raises(ValueError, match="Cache must implement CacheInterface"): + ConnectionPool(protocol=3, use_cache=True, cache=TTLCache(100, 20)) def test_returns_custom_cache_implementation(self, mock_cache): - connection_pool = ConnectionPool( - protocol=3, - use_cache=True, - cache=mock_cache - ) + connection_pool = ConnectionPool(protocol=3, use_cache=True, cache=mock_cache) assert mock_cache == connection_pool.cache connection_pool.disconnect() - def test_creates_cache_with_custom_cache_factory(self, mock_cache_factory, mock_cache): + def test_creates_cache_with_custom_cache_factory( + self, mock_cache_factory, mock_cache + ): mock_cache_factory.get_cache.return_value = mock_cache connection_pool = ConnectionPool( @@ -403,7 +391,7 @@ def test_creates_cache_with_custom_cache_factory(self, mock_cache_factory, mock_ cache_size=100, cache_ttl=20, cache_eviction=EvictionPolicy.TTL, - cache_factory=mock_cache_factory + cache_factory=mock_cache_factory, ) assert connection_pool.cache == mock_cache @@ -415,7 +403,7 @@ def test_creates_cache_with_given_configuration(self, mock_cache): use_cache=True, cache_size=100, cache_ttl=20, - cache_eviction=EvictionPolicy.TTL + cache_eviction=EvictionPolicy.TTL, ) assert isinstance(connection_pool.cache, CacheInterface) @@ -424,10 +412,7 @@ def test_creates_cache_with_given_configuration(self, mock_cache): connection_pool.disconnect() def test_make_connection_proxy_connection_on_given_cache(self): - connection_pool = ConnectionPool( - protocol=3, - use_cache=True - ) + connection_pool = ConnectionPool(protocol=3, use_cache=True) assert isinstance(connection_pool.make_connection(), CacheProxyConnection) connection_pool.disconnect() @@ -436,16 +421,17 @@ def test_make_connection_proxy_connection_on_given_cache(self): class TestUnitCacheProxyConnection: def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): cache = LRUCache(100) - cache['key'] = 'value' - assert cache['key'] == 'value' + cache["key"] = "value" + assert cache["key"] == "value" mock_connection.disconnect.return_value = None - mock_connection.retry = 'mock' - mock_connection.host = 'mock' - mock_connection.port = 'mock' + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" - proxy_connection = CacheProxyConnection(mock_connection, CacheToolsAdapter(cache), cache_conf) + proxy_connection = CacheProxyConnection( + mock_connection, CacheToolsAdapter(cache), cache_conf + ) proxy_connection.disconnect() assert cache.currsize == 0 - diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 465e27f458..8ccb2125c4 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2,7 +2,6 @@ import time import pytest - from redis.scheduler import Scheduler @@ -15,10 +14,10 @@ class TestScheduler: (0.1, 2, (0, 0)), ], ids=[ - 'small polling period (0.001s)', - 'large polling period (0.1s)', - 'interval larger than timeout - no execution', - ] + "small polling period (0.001s)", + "large polling period (0.1s)", + "interval larger than timeout - no execution", + ], ) def test_run_with_interval(self, polling_period, interval, expected_count): scheduler = Scheduler(polling_period=polling_period) @@ -30,7 +29,9 @@ def callback(done: threading.Event): counter += 1 done.set() - scheduler.run_with_interval(func=callback, interval=interval, cancel=cancel_event) + scheduler.run_with_interval( + func=callback, interval=interval, cancel=cancel_event + ) time.sleep(1) cancel_event.set() cancel_event.wait() From 36de29b4242e7213508301f6c4fbe0299f3ebbdf Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 09:37:55 +0300 Subject: [PATCH 39/78] Updated RESP2 restriction --- tests/test_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 0323be0a82..beeb87867e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -38,7 +38,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -# @skip_if_resp_version(2) +@skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", From b0236191b86bebde9708d04b53cd77fd70c6bfb0 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 09:43:03 +0300 Subject: [PATCH 40/78] Cahnged typing to more generic --- redis/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/scheduler.py b/redis/scheduler.py index a6a0e7ff12..9bcfc740a0 100644 --- a/redis/scheduler.py +++ b/redis/scheduler.py @@ -14,7 +14,7 @@ def __init__(self, polling_period: float = 0.1): def run_with_interval( self, - func: Callable[[threading.Event, ...], None], + func: Callable, interval: float, cancel: threading.Event, args: tuple = (), @@ -31,14 +31,14 @@ def run_with_interval( return thread def _get_timer( - self, func: Callable[[threading.Event, ...], None], interval: float, args: tuple + self, func: Callable, interval: float, args: tuple ) -> threading.Timer: timer = threading.Timer(interval=interval, function=func, args=args) return timer def _run_timer( self, - func: Callable[[threading.Event, ...], None], + func: Callable, interval: float, args: tuple, done: threading.Event, From 9fe36f8b0a846752b9bed5690b852ee69cf9c07c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 10:13:17 +0300 Subject: [PATCH 41/78] Restrict pytest-asyncio version to 0.23 --- dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 9741dc0555..4304facab1 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -8,7 +8,7 @@ invoke==2.2.0 mock packaging>=20.4 pytest -pytest-asyncio +pytest-asyncio>=0.23.0 pytest-cov pytest-profiling pytest-timeout From 8a2a01e4d0f118cb487f8a6c9df809e5b72b1e5d Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 10:15:55 +0300 Subject: [PATCH 42/78] Added upper version limit --- dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 4304facab1..3bd0be2763 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -8,7 +8,7 @@ invoke==2.2.0 mock packaging>=20.4 pytest -pytest-asyncio>=0.23.0 +pytest-asyncio>=0.23.0,<0.24.0 pytest-cov pytest-profiling pytest-timeout From 7b503006a8ffd5370b3b0058280d48461e9baf62 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 10:26:35 +0300 Subject: [PATCH 43/78] Removed usntable multithreaded tests --- tests/test_cache.py | 60 +-------------------------------------------- 1 file changed, 1 insertion(+), 59 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index beeb87867e..0ee8053b0d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -38,7 +38,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +#@skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", @@ -200,36 +200,6 @@ def test_health_check_invalidate_cache(self, r): # Make sure that value was invalidated assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize( - "r", - [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache_multithreaded(self, r, r2): - cache = r.get_cache() - # Running commands over two threads - threading.Thread(target=r.set("foo", "bar")).start() - threading.Thread(target=r.set("bar", "foo")).start() - # Wait for command execution to be finished - time.sleep(0.1) - # get keys from server - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - # Wait for command execution to be finished - time.sleep(0.1) - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - # change key in redis (cause invalidation) - threading.Thread(target=r2.set("foo", "baz")).start() - threading.Thread(target=r2.set("bar", "bar")).start() - # Wait for health check - time.sleep(2) - # Make sure that value was invalidated - assert cache.get(("GET", "foo")) is None - assert cache.get(("GET", "bar")) is None - @pytest.mark.parametrize( "r", [ @@ -605,34 +575,6 @@ def test_health_check_invalidate_cache(self, r, r2): # Make sure that value was invalidated assert cache.get(("GET", "foo")) is None - @pytest.mark.parametrize( - "r", - [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], - indirect=True, - ) - @pytest.mark.onlycluster - def test_health_check_invalidate_cache_multithreaded(self, r, r2): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() - # Running commands over two threads - threading.Thread(target=r.set("foo", "bar")).start() - threading.Thread(target=r.set("bar", "foo")).start() - # Wait for command execution to be finished - time.sleep(0.1) - # get keys from server - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - # change key in redis (cause invalidation) - threading.Thread(target=r.set("foo", "baz")).start() - threading.Thread(target=r.set("bar", "bar")).start() - # Wait for health check - time.sleep(2) - # Make sure that value was invalidated - assert cache.get(("GET", "foo")) is None - assert cache.get(("GET", "bar")) is None - @pytest.mark.parametrize( "r", [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], From 73cb06856f9ed372fa792b85b292bbb3da1500f9 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 10:33:22 +0300 Subject: [PATCH 44/78] Removed more flacky multithreaded tests --- tests/test_cache.py | 195 +------------------------------------------- 1 file changed, 1 insertion(+), 194 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 0ee8053b0d..62dfdaff0f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -38,7 +38,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -#@skip_if_resp_version(2) +# @skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", @@ -120,59 +120,6 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize( - "r", - [ - { - "cache": CacheToolsAdapter(TTLCache(128, 300)), - "use_cache": True, - "single_connection_client": True, - }, - { - "cache": CacheToolsAdapter(TTLCache(128, 300)), - "use_cache": True, - "single_connection_client": False, - }, - ], - ids=["single", "pool"], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_get_from_cache_multithreaded(self, r): - cache = r.get_cache() - # Running commands over two threads - threading.Thread(target=r.set("foo", "bar")).start() - threading.Thread(target=r.set("bar", "foo")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that responses was cached. - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - - threading.Thread(target=r.set("foo", "baz")).start() - threading.Thread(target=r.set("bar", "bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that new values was cached. - assert cache.get(("GET", "foo")) == b"baz" - assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize( "r", [ @@ -512,48 +459,6 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize( - "r", - [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], - indirect=True, - ) - @pytest.mark.onlycluster - def test_get_from_cache_multithreaded(self, r): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() - # Running commands over two threads - threading.Thread(target=r.set("foo", "bar")).start() - threading.Thread(target=r.set("bar", "foo")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that both values was cached. - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - - # Running commands over two threads - threading.Thread(target=r.set("foo", "baz")).start() - threading.Thread(target=r.set("bar", "bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that new values was cached. - assert cache.get(("GET", "foo")) == b"baz" - assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize( "r", [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], @@ -809,57 +714,6 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize( - "sentinel_setup", - [ - { - "cache": CacheToolsAdapter(LRUCache(maxsize=128)), - "use_cache": True, - "force_master_ip": "localhost", - } - ], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_get_from_cache_multithreaded(self, master): - cache = master.get_cache() - - # Running commands over two threads - threading.Thread(target=master.set("foo", "bar")).start() - threading.Thread(target=master.set("bar", "foo")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Running commands over two threads - threading.Thread(target=master.get("foo")).start() - threading.Thread(target=master.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that both values was cached. - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - - # Running commands over two threads - threading.Thread(target=master.set("foo", "baz")).start() - threading.Thread(target=master.set("bar", "bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Running commands over two threads - threading.Thread(target=master.get("foo")).start() - threading.Thread(target=master.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that new values was cached. - assert cache.get(("GET", "foo")) == b"baz" - assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize( "sentinel_setup", [ @@ -994,53 +848,6 @@ def test_get_from_custom_cache(self, request, r, r2): # Make sure that new value was cached assert cache.get(("GET", "foo")) == b"barbar" - @pytest.mark.parametrize( - "r", - [ - { - "cache": CacheToolsAdapter(TTLCache(128, 300)), - "use_cache": True, - "ssl": True, - } - ], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_get_from_cache_multithreaded(self, r): - cache = r.get_cache() - # Running commands over two threads - threading.Thread(target=r.set("foo", "bar")).start() - threading.Thread(target=r.set("bar", "foo")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that responses was cached. - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - - threading.Thread(target=r.set("foo", "baz")).start() - threading.Thread(target=r.set("bar", "bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - threading.Thread(target=r.get("foo")).start() - threading.Thread(target=r.get("bar")).start() - - # Wait for command execution to be finished - time.sleep(0.1) - - # Make sure that new values was cached. - assert cache.get(("GET", "foo")) == b"baz" - assert cache.get(("GET", "bar")) == b"bar" - @pytest.mark.parametrize( "r", [ From ad7e9774004315c8bad4d49d0a437196c4f7e35e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 27 Aug 2024 11:55:57 +0300 Subject: [PATCH 45/78] Fixed issue with Sentinel killing healthcheck thread before execution --- redis/connection.py | 16 +++++++++------- redis/sentinel.py | 1 + tests/test_cache.py | 3 +-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index a6129c62f9..c2c79d38c7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1307,13 +1307,7 @@ def __init__( # release the lock. self._fork_lock = threading.Lock() self.reset() - - # Run scheduled healthcheck to avoid stale invalidations in idle connections. - if self.cache is not None and self._scheduler is not None: - self._hc_cancel_event = threading.Event() - self._hc_thread = self._scheduler.run_with_interval( - self._perform_health_check, 2, self._hc_cancel_event - ) + self.run_scheduled_healthcheck() def __repr__(self) -> (str, str): return ( @@ -1513,6 +1507,14 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry + def run_scheduled_healthcheck(self) -> None: + # Run scheduled healthcheck to avoid stale invalidations in idle connections. + if self.cache is not None and self._scheduler is not None: + self._hc_cancel_event = threading.Event() + self._hc_thread = self._scheduler.run_with_interval( + self._perform_health_check, 2, self._hc_cancel_event + ) + def _perform_health_check(self, done: threading.Event) -> None: self._checkpid() with self._lock: diff --git a/redis/sentinel.py b/redis/sentinel.py index 01e210794c..17cc926a98 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -115,6 +115,7 @@ def get_master_address(self): connection_pool = self.connection_pool_ref() if connection_pool is not None: connection_pool.disconnect(inuse_connections=False) + connection_pool.run_scheduled_healthcheck() return master_address def rotate_slaves(self): diff --git a/tests/test_cache.py b/tests/test_cache.py index 62dfdaff0f..5705094bac 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,4 +1,3 @@ -import threading import time import pytest @@ -38,7 +37,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -# @skip_if_resp_version(2) +@skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", From c106873f9f446b0d74ab1b7a13eb11ab7071e1b5 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 29 Aug 2024 14:14:26 +0300 Subject: [PATCH 46/78] Removed cachetools dependency, added custom cache implementation --- dev_requirements.txt | 1 - docs/examples/connection_examples.ipynb | 2 - redis/cache.py | 395 +++++++--- redis/client.py | 10 +- redis/cluster.py | 30 +- redis/commands/core.py | 6 +- redis/connection.py | 138 ++-- requirements.txt | 3 +- tests/conftest.py | 29 +- tests/test_cache.py | 981 +++++++++++++++--------- tests/test_cluster.py | 4 +- tests/test_connection.py | 45 +- 12 files changed, 1027 insertions(+), 617 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 3bd0be2763..37a107d16d 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,4 @@ black==24.3.0 -cachetools>=5.5.0 click==8.0.4 flake8-isort flake8 diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index cddded2865..fd60e2a495 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -69,9 +69,7 @@ }, { "cell_type": "markdown", - "execution_count": null, "metadata": {}, - "outputs": [], "source": [ "### By default this library uses the RESP 2 protocol. To enable RESP3, set protocol=3." ] diff --git a/redis/cache.py b/redis/cache.py index 983454b411..8903569f99 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,19 +1,265 @@ +import copy +import time from abc import ABC, abstractmethod +from collections import OrderedDict from enum import Enum -from typing import Any, Hashable +from typing import Any, Collection, Hashable, List, Optional -from cachetools import Cache, LFUCache, LRUCache, RRCache, TTLCache + +class CacheEntryStatus(Enum): + VALID = "VALID" + IN_PROGRESS = "IN_PROGRESS" + + +class EvictionPolicyType(Enum): + time_based = "time_based" + frequency_based = "frequency_based" + + +class CacheKey: + def __init__(self, command: str, redis_keys: tuple[str, ...]): + self.command = command + self.redis_keys = redis_keys + + def get_redis_keys(self) -> tuple[str, ...]: + return self.redis_keys + + def __hash__(self): + return hash((self.command, self.redis_keys)) + + def __eq__(self, other): + return hash(self) == hash(other) + + +class CacheEntry: + def __init__( + self, cache_key: CacheKey, cache_value: bytes, status: CacheEntryStatus + ): + self.cache_key = cache_key + self.cache_value = cache_value + self.status = status + + +class EvictionPolicyInterface(ABC): + @property + @abstractmethod + def cache(self): + pass + + @cache.setter + def cache(self, value): + pass + + @property + @abstractmethod + def type(self) -> EvictionPolicyType: + pass + + @abstractmethod + def evict_next(self) -> CacheKey: + pass + + @abstractmethod + def evict_many(self, count: int) -> List[CacheKey]: + pass + + @abstractmethod + def touch(self, cache_key: CacheKey) -> None: + pass + + +class CacheInterface(ABC): + @abstractmethod + def get_collection(self) -> OrderedDict[CacheKey, CacheEntry]: + pass + + @abstractmethod + def get_eviction_policy(self) -> EvictionPolicyInterface: + pass + + @abstractmethod + def get_max_size(self) -> int: + pass + + @abstractmethod + def get(self, key: CacheKey) -> CacheEntry | None: + pass + + @abstractmethod + def set(self, entry: CacheEntry) -> bool: + pass + + @abstractmethod + def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: + pass + + @abstractmethod + def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: + pass + + @abstractmethod + def flush(self) -> int: + pass + + @abstractmethod + def is_cachable(self, key: CacheKey) -> bool: + pass + + +class CacheConfigurationInterface(ABC): + @abstractmethod + def get_cache_class(self): + pass + + @abstractmethod + def get_max_size(self) -> int: + pass + + @abstractmethod + def get_eviction_policy(self): + pass + + @abstractmethod + def is_exceeds_max_size(self, count: int) -> bool: + pass + + @abstractmethod + def is_allowed_to_cache(self, command: str) -> bool: + pass + + +class DefaultCache(CacheInterface): + def __init__( + self, + cache_config: CacheConfigurationInterface, + ) -> None: + self._cache = OrderedDict() + self._cache_config = cache_config + self._eviction_policy = self._cache_config.get_eviction_policy().value() + self._eviction_policy.cache = self + + def get_collection(self) -> OrderedDict[CacheKey, CacheEntry]: + return self._cache + + def get_eviction_policy(self) -> EvictionPolicyInterface: + return self._eviction_policy + + def get_max_size(self) -> int: + return self._cache_config.get_max_size() + + def set(self, entry: CacheEntry) -> bool: + if not self.is_cachable(entry.cache_key): + return False + + self._cache[entry.cache_key] = entry + self._eviction_policy.touch(entry.cache_key) + + if self._cache_config.is_exceeds_max_size(len(self._cache)): + self._eviction_policy.evict_next() + + return True + + def get(self, key: CacheKey) -> CacheEntry | None: + entry = self._cache.get(key, None) + + if entry is None: + return None + + self._eviction_policy.touch(key) + return copy.deepcopy(entry) + + def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: + response = [] + + for key in cache_keys: + if self.get(key) is not None: + self._cache.pop(key) + response.append(True) + else: + response.append(False) + + return response + + def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: + response = [] + keys_to_delete = [] + + for redis_key in redis_keys: + redis_key = redis_key.decode() + for cache_key in self._cache: + if redis_key in cache_key.get_redis_keys(): + keys_to_delete.append(cache_key) + response.append(True) + + for key in keys_to_delete: + self._cache.pop(key) + + return response + + def flush(self) -> int: + elem_count = len(self._cache) + self._cache.clear() + return elem_count + + def is_cachable(self, key: CacheKey) -> bool: + return self._cache_config.is_allowed_to_cache(key.command) + + +class LRUPolicy(EvictionPolicyInterface): + def __init__(self): + self.cache = None + + @property + def cache(self): + return self._cache + + @cache.setter + def cache(self, cache: CacheInterface): + self._cache = cache + + @property + def type(self) -> EvictionPolicyType: + return EvictionPolicyType.time_based + + def evict_next(self) -> CacheKey: + self._assert_cache() + popped_entry = self._cache.get_collection().popitem(last=False) + return popped_entry[0] + + def evict_many(self, count: int) -> List[CacheKey]: + self._assert_cache() + if count > len(self._cache.get_collection()): + raise ValueError("Evictions count is above cache size") + + popped_keys = [] + + for _ in range(count): + popped_entry = self._cache.get_collection().popitem(last=False) + popped_keys.append(popped_entry[0]) + + return popped_keys + + def touch(self, cache_key: CacheKey) -> None: + self._assert_cache() + + if self._cache.get_collection().get(cache_key) is None: + raise ValueError(f"Given entry does not belong to the cache") + + self._cache.get_collection().move_to_end(cache_key) + + def _assert_cache(self): + if self.cache is None or not isinstance(self.cache, CacheInterface): + raise ValueError("Eviction policy should be associated with valid cache.") class EvictionPolicy(Enum): - LRU = "LRU" - LFU = "LFU" - RANDOM = "RANDOM" - TTL = "TTL" + LRU = LRUPolicy -class CacheConfiguration: +class CacheConfiguration(CacheConfigurationInterface): + DEFAULT_CACHE_CLASS = DefaultCache DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU + DEFAULT_MAX_SIZE = 10000 DEFAULT_ALLOW_LIST = [ "BITCOUNT", @@ -92,22 +338,18 @@ class CacheConfiguration: "ZUNION", ] - def __init__(self, **kwargs): - self._max_size = kwargs.get("cache_size", None) - self._ttl = kwargs.get("cache_ttl", None) - self._eviction_policy = kwargs.get("cache_eviction", None) - if self._max_size is None: - self._max_size = 10000 - if self._ttl is None: - self._ttl = 0 - if self._eviction_policy is None: - self._eviction_policy = EvictionPolicy.LRU + def __init__( + self, + max_size: int = DEFAULT_MAX_SIZE, + cache_class: Any = DEFAULT_CACHE_CLASS, + eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, + ): + self._cache_class = cache_class + self._max_size = max_size + self._eviction_policy = eviction_policy - if self._eviction_policy not in EvictionPolicy: - raise ValueError(f"Invalid eviction_policy {self._eviction_policy}") - - def get_ttl(self) -> int: - return self._ttl + def get_cache_class(self): + return self._cache_class def get_max_size(self) -> int: return self._max_size @@ -122,112 +364,19 @@ def is_allowed_to_cache(self, command: str) -> bool: return command in self.DEFAULT_ALLOW_LIST -class EvictionPolicyCacheClass(Enum): - LRU = LRUCache - LFU = LFUCache - RANDOM = RRCache - TTL = TTLCache - - -class CacheClassEvictionPolicy(Enum): - LRUCache = EvictionPolicy.LRU - LFUCache = EvictionPolicy.LFU - RRCache = EvictionPolicy.RANDOM - TTLCache = EvictionPolicy.TTL - - -class CacheInterface(ABC): - - @property - @abstractmethod - def currsize(self) -> float: - pass - - @property - @abstractmethod - def maxsize(self) -> float: - pass - - @property - @abstractmethod - def eviction_policy(self) -> EvictionPolicy: - pass - - @abstractmethod - def get(self, key: Hashable, default: Any = None): - pass - - @abstractmethod - def set(self, key: Hashable, value: Any): - pass - - @abstractmethod - def exists(self, key: Hashable) -> bool: - pass - - @abstractmethod - def remove(self, key: Hashable): - pass - - @abstractmethod - def clear(self): - pass - - class CacheFactoryInterface(ABC): @abstractmethod def get_cache(self) -> CacheInterface: pass -class CacheToolsFactory(CacheFactoryInterface): - def __init__(self, conf: CacheConfiguration): - self._conf = conf - - def get_cache(self) -> CacheInterface: - eviction_policy = self._conf.get_eviction_policy() - cache_class = self._get_cache_class(eviction_policy).value - - if eviction_policy == EvictionPolicy.TTL: - cache_inst = cache_class(self._conf.get_max_size(), self._conf.get_ttl()) - else: - cache_inst = cache_class(self._conf.get_max_size()) - - return CacheToolsAdapter(cache_inst) - - def _get_cache_class( - self, eviction_policy: EvictionPolicy - ) -> EvictionPolicyCacheClass: - return EvictionPolicyCacheClass[eviction_policy.value] - - -class CacheToolsAdapter(CacheInterface): - def __init__(self, cache: Cache): - self._cache = cache - - def get(self, key: Hashable, default: Any = None): - return self._cache.get(key, default) +class CacheFactory(CacheFactoryInterface): + def __init__(self, cache_config: Optional[CacheConfiguration] = None): + self._config = cache_config - def set(self, key: Hashable, value: Any): - self._cache[key] = value + if self._config is None: + self._config = CacheConfiguration() - def exists(self, key: Hashable) -> bool: - return key in self._cache - - def remove(self, key: Hashable): - self._cache.pop(key) - - def clear(self): - self._cache.clear() - - @property - def currsize(self) -> float: - return self._cache.currsize - - @property - def maxsize(self) -> float: - return self._cache.maxsize - - @property - def eviction_policy(self) -> EvictionPolicy: - return CacheClassEvictionPolicy[self._cache.__class__.__name__].value + def get_cache(self) -> CacheInterface: + cache_class = self._config.get_cache_class() + return cache_class(cache_config=self._config) diff --git a/redis/client.py b/redis/client.py index a9e5ec6bd9..220fbf238b 100755 --- a/redis/client.py +++ b/redis/client.py @@ -13,7 +13,7 @@ _RedisCallbacksRESP3, bool_ok, ) -from redis.cache import CacheInterface, EvictionPolicy +from redis.cache import CacheConfiguration, CacheInterface from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -215,9 +215,7 @@ def __init__( protocol: Optional[int] = 2, use_cache: bool = False, cache: Optional[CacheInterface] = None, - cache_eviction: Optional[EvictionPolicy] = None, - cache_size: int = 10000, - cache_ttl: int = 0, + cache_config: Optional[CacheConfiguration] = None, ) -> None: """ Initialize a new Redis client. @@ -315,9 +313,7 @@ def __init__( { "use_cache": use_cache, "cache": cache, - "cache_eviction": cache_eviction, - "cache_size": cache_size, - "cache_ttl": cache_ttl, + "cache_config": cache_config, } ) connection_pool = ConnectionPool(**kwargs) diff --git a/redis/cluster.py b/redis/cluster.py index 3058d36a7b..1635d0f8b7 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,7 +9,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.cache import CacheInterface, EvictionPolicy +from redis.cache import CacheConfiguration, CacheInterface from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -170,8 +170,7 @@ def parse_cluster_myshardid(resp, **options): "username", "use_cache", "cache", - "cache_size", - "cache_ttl", + "cache_config", ) KWARGS_DISABLED_KEYS = ("host", "port") @@ -507,9 +506,7 @@ def __init__( address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, use_cache: bool = False, cache: Optional[CacheInterface] = None, - cache_eviction: Optional[EvictionPolicy] = None, - cache_size: int = 10000, - cache_ttl: int = 0, + cache_config: Optional[CacheConfiguration] = None, **kwargs, ): """ @@ -651,9 +648,7 @@ def __init__( address_remap=address_remap, use_cache=use_cache, cache=cache, - cache_eviction=cache_eviction, - cache_size=cache_size, - cache_ttl=cache_ttl, + cache_config=cache_config, **kwargs, ) @@ -1335,9 +1330,7 @@ def __init__( address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, use_cache: bool = False, cache: Optional[CacheInterface] = None, - cache_eviction: Optional[EvictionPolicy] = None, - cache_size: int = 10000, - cache_ttl: int = 0, + cache_config: Optional[CacheConfiguration] = None, **kwargs, ): self.nodes_cache = {} @@ -1352,9 +1345,7 @@ def __init__( self.address_remap = address_remap self.use_cache = use_cache self.cache = cache - self.cache_eviction = cache_eviction - self.cache_size = cache_size - self.cache_ttl = cache_ttl + self.cache_config = cache_config self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1500,9 +1491,7 @@ def create_redis_node(self, host, port, **kwargs): kwargs.update({"port": port}) kwargs.update({"use_cache": self.use_cache}) kwargs.update({"cache": self.cache}) - kwargs.update({"cache_eviction": self.cache_eviction}) - kwargs.update({"cache_size": self.cache_size}) - kwargs.update({"cache_ttl": self.cache_ttl}) + kwargs.update({"cache_config": self.cache_config}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: r = Redis( @@ -1510,9 +1499,7 @@ def create_redis_node(self, host, port, **kwargs): port=port, use_cache=self.use_cache, cache=self.cache, - cache_eviction=self.cache_eviction, - cache_size=self.cache_size, - cache_ttl=self.cache_ttl, + cache_config=self.cache_config, **kwargs, ) return r @@ -1563,6 +1550,7 @@ def initialize(self): # Make sure cluster mode is enabled on this node try: cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) + r.connection_pool.disconnect() except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" diff --git a/redis/commands/core.py b/redis/commands/core.py index b356d101ee..883c7d4dae 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5707,7 +5707,7 @@ def script_exists(self, *args: str) -> ResponseT: """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if - if each already script exists in the cache. + if each already script exists in the cache_data. For more information see https://redis.io/commands/script-exists """ @@ -5721,7 +5721,7 @@ def script_debug(self, *args) -> None: def script_flush( self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None ) -> ResponseT: - """Flush all scripts from the script cache. + """Flush all scripts from the script cache_data. ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC. @@ -5752,7 +5752,7 @@ def script_kill(self) -> ResponseT: def script_load(self, script: ScriptTextT) -> ResponseT: """ - Load a Lua ``script`` into the script cache. Returns the SHA. + Load a Lua ``script`` into the script cache_data. Returns the SHA. For more information see https://redis.io/commands/script-load """ diff --git a/redis/connection.py b/redis/connection.py index c2c79d38c7..95e7cb074d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,20 +8,20 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from time import sleep, time +from time import time from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from cachetools import LRUCache -from cachetools.keys import hashkey from redis.cache import ( CacheConfiguration, + CacheEntry, + CacheEntryStatus, + CacheFactory, CacheFactoryInterface, CacheInterface, - CacheToolsFactory, + CacheKey, ) -from . import scheduler from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider @@ -731,12 +731,9 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): - CACHE_DUMMY_STATUS = "caching-in-progress" - KEYS_MAPPING_CACHE_SIZE = 10000 + DUMMY_CACHE_VALUE = b"foo" - def __init__( - self, conn: ConnectionInterface, cache: CacheInterface, conf: CacheConfiguration - ): + def __init__(self, conn: ConnectionInterface, cache: CacheInterface): self.pid = os.getpid() self._conn = conn self.retry = self._conn.retry @@ -744,11 +741,8 @@ def __init__( self.port = self._conn.port self._cache = cache self._cache_lock = threading.Lock() - self._conf = conf - self._current_command_hash = None - self._current_command_keys = None + self._current_command_cache_key = None self._current_options = None - self._keys_mapping = LRUCache(maxsize=self.KEYS_MAPPING_CACHE_SIZE) self.register_connect_callback(self._enable_tracking_callback) def repr_pieces(self): @@ -771,8 +765,7 @@ def on_connect(self): def disconnect(self, *args): with self._cache_lock: - self._cache.clear() - self._keys_mapping.clear() + self._cache.flush() self._conn.disconnect(*args) def check_health(self): @@ -786,33 +779,37 @@ def send_packed_command(self, command, check_health=True): def send_command(self, *args, **kwargs): self._process_pending_invalidations() - # If command is write command or not allowed to cache - # transfer control to the actual connection. - if not self._conf.is_allowed_to_cache(args[0]): - self._current_command_hash = None - self._current_command_keys = None - self._conn.send_command(*args, **kwargs) - return - - # Create hash representation of current executed command. - self._current_command_hash = hashkey(*args) + with self._cache_lock: + # Command is write command or not allowed + # to be cached. + if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())): + self._current_command_cache_key = None + self._conn.send_command(*args, **kwargs) + return - # Extract keys from current command. - if kwargs.get("keys"): - self._current_command_keys = kwargs["keys"] + if kwargs.get("keys") is None: + raise ValueError("Cannot create cache key.") - if not isinstance(self._current_command_keys, list): - raise TypeError("Cache keys must be a list.") + # Creates cache key. + self._current_command_cache_key = CacheKey( + command=args[0], redis_keys=tuple(kwargs.get("keys")) + ) with self._cache_lock: # If current command reply already cached # prevent sending data over socket. - if self._cache.get(self._current_command_hash): + if self._cache.get(self._current_command_cache_key): return - # Set temporary entry as a status to prevent + # Set temporary entry value to prevent # race condition from another connection. - self._cache.set(self._current_command_hash, self.CACHE_DUMMY_STATUS) + self._cache.set( + CacheEntry( + cache_key=self._current_command_cache_key, + cache_value=self.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS, + ) + ) # Send command over socket only if it's allowed # read-only command that not yet cached. @@ -827,11 +824,12 @@ def read_response( with self._cache_lock: # Check if command response exists in a cache and it's not in progress. if ( - self._cache.exists(self._current_command_hash) - and self._cache.get(self._current_command_hash) - != self.CACHE_DUMMY_STATUS + self._current_command_cache_key is not None + and self._cache.get(self._current_command_cache_key) is not None + and self._cache.get(self._current_command_cache_key).status + != CacheEntryStatus.IN_PROGRESS ): - return copy.deepcopy(self._cache.get(self._current_command_hash)) + return self._cache.get(self._current_command_cache_key).cache_value response = self._conn.read_response( disable_decoding=disable_decoding, @@ -840,29 +838,26 @@ def read_response( ) with self._cache_lock: + # Prevent not-allowed command from caching. + if self._current_command_cache_key is None: + return response # If response is None prevent from caching. if response is None: - self._cache.remove(self._current_command_hash) - return response - # Prevent not-allowed command from caching. - elif self._current_command_hash is None: + self._cache.delete_by_cache_keys([self._current_command_cache_key]) return response - # Create separate mapping for keys - # or add current response to associated keys. - for key in self._current_command_keys: - if key in self._keys_mapping: - if self._current_command_hash not in self._keys_mapping[key]: - self._keys_mapping[key].append(self._current_command_hash) - else: - self._keys_mapping[key] = [self._current_command_hash] - - cache_entry = self._cache.get(self._current_command_hash, None) + cache_entry = self._cache.get(self._current_command_cache_key) # Cache only responses that still valid # and wasn't invalidated by another connection in meantime. if cache_entry is not None: - self._cache.set(self._current_command_hash, response) + self._cache.set( + CacheEntry( + cache_key=self._current_command_cache_key, + cache_value=response, + status=CacheEntryStatus.VALID, + ) + ) return response @@ -887,21 +882,13 @@ def _process_pending_invalidations(self): while self.retry.call_with_retry_on_false(lambda: self.can_read()): self._conn.read_response(push_request=True) - def _on_invalidation_callback(self, data: List[Union[str, Optional[List[str]]]]): + def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): with self._cache_lock: # Flush cache when DB flushed on server-side if data[1] is None: - self._cache.clear() + self._cache.flush() else: - for key in data[1]: - normalized_key = ensure_string(key) - if normalized_key in self._keys_mapping: - # Make sure that all command responses - # associated with this key will be deleted - for cache_key in self._keys_mapping[normalized_key]: - self._cache.remove(cache_key) - # Removes key from mapping cache - self._keys_mapping.pop(normalized_key) + self._cache.delete_by_redis_keys(data[1]) class SSLConnection(Connection): @@ -1274,8 +1261,6 @@ def __init__( if connection_kwargs.get("protocol") not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") - self._cache_conf = CacheConfiguration(**self.connection_kwargs) - cache = self.connection_kwargs.get("cache") if cache is not None: @@ -1287,15 +1272,15 @@ def __init__( if self._cache_factory is not None: self.cache = self._cache_factory.get_cache() else: - self.cache = CacheToolsFactory(self._cache_conf).get_cache() + self.cache = CacheFactory( + self.connection_kwargs.get("cache_config") + ).get_cache() self._scheduler = Scheduler() connection_kwargs.pop("use_cache", None) - connection_kwargs.pop("cache_eviction", None) - connection_kwargs.pop("cache_size", None) - connection_kwargs.pop("cache_ttl", None) connection_kwargs.pop("cache", None) + connection_kwargs.pop("cache_config", None) # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as @@ -1435,11 +1420,9 @@ def make_connection(self) -> "ConnectionInterface": raise ConnectionError("Too many connections") self._created_connections += 1 - if self.cache is not None and self._cache_conf is not None: + if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), - self.cache, - self._cache_conf, + self.connection_class(**self.connection_kwargs), self.cache ) return self.connection_class(**self.connection_kwargs) @@ -1601,12 +1584,9 @@ def reset(self): def make_connection(self): "Make a fresh connection." - if self.cache is not None and self._cache_conf is not None: + if self.cache is not None: connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), - self.cache, - self._cache_conf, - self._cache_lock, + self.connection_class(**self.connection_kwargs), self.cache ) else: connection = self.connection_class(**self.connection_kwargs) diff --git a/requirements.txt b/requirements.txt index 7ddaf28d97..622f70b810 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -async-timeout>=4.0.3 -cachetools>=5.5.0 \ No newline at end of file +async-timeout>=4.0.3 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index ce36893155..3b73bedf95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import argparse import random -import threading import time from typing import Callable, TypeVar from unittest import mock @@ -9,7 +8,6 @@ import pytest import redis -from _pytest import unittest from packaging.version import Version from redis import Sentinel from redis.backoff import NoBackoff @@ -17,15 +15,10 @@ CacheConfiguration, CacheFactoryInterface, CacheInterface, + CacheKey, EvictionPolicy, ) -from redis.connection import ( - Connection, - ConnectionInterface, - ConnectionPool, - SSLConnection, - parse_url, -) +from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url from redis.exceptions import RedisClusterException from redis.retry import Retry from tests.ssl_utils import get_ssl_filename @@ -448,8 +441,7 @@ def sentinel_setup(request): kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} use_cache = request.param.get("use_cache", False) cache = request.param.get("cache", None) - cache_size = request.param.get("cache_size", 128) - cache_ttl = request.param.get("cache_ttl", 300) + cache_config = request.param.get("cache_config", None) force_master_ip = request.param.get("force_master_ip", None) sentinel = Sentinel( sentinel_endpoints, @@ -457,8 +449,7 @@ def sentinel_setup(request): socket_timeout=0.1, use_cache=use_cache, cache=cache, - cache_ttl=cache_ttl, - cache_size=cache_size, + cache_config=cache_config, protocol=3, **kwargs, ) @@ -553,9 +544,7 @@ def master_host(request): @pytest.fixture() def cache_conf() -> CacheConfiguration: - return CacheConfiguration( - cache_size=100, cache_ttl=20, cache_eviction=EvictionPolicy.TTL - ) + return CacheConfiguration(max_size=100, eviction_policy=EvictionPolicy.LRU) @pytest.fixture() @@ -576,6 +565,14 @@ def mock_connection() -> ConnectionInterface: return mock_connection +@pytest.fixture() +def cache_key(request) -> CacheKey: + command = request.param.get("command") + keys = request.param.get("redis_keys") + + return CacheKey(command, keys) + + def wait_for_command(client, monitor, command, key=None): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting diff --git a/tests/test_cache.py b/tests/test_cache.py index 5705094bac..e4298264af 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -2,8 +2,16 @@ import pytest import redis -from cachetools import LFUCache, LRUCache, TTLCache -from redis.cache import CacheConfiguration, CacheToolsAdapter, EvictionPolicy +from redis.cache import ( + CacheConfiguration, + CacheEntry, + CacheEntryStatus, + CacheKey, + DefaultCache, + EvictionPolicy, + EvictionPolicyType, + LRUPolicy, +) from redis.utils import HIREDIS_AVAILABLE from tests.conftest import _get_client, skip_if_resp_version @@ -12,9 +20,7 @@ def r(request): use_cache = request.param.get("use_cache", False) cache = request.param.get("cache") - cache_eviction = request.param.get("cache_eviction") - cache_size = request.param.get("cache_size") - cache_ttl = request.param.get("cache_ttl") + cache_config = request.param.get("cache_config") kwargs = request.param.get("kwargs", {}) protocol = request.param.get("protocol", 3) ssl = request.param.get("ssl", False) @@ -27,9 +33,7 @@ def r(request): single_connection_client=single_connection_client, use_cache=use_cache, cache=cache, - cache_eviction=cache_eviction, - cache_size=cache_size, - cache_ttl=cache_ttl, + cache_config=cache_config, **kwargs, ) as client: yield client @@ -43,12 +47,12 @@ class TestCache: "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), + "cache": DefaultCache(CacheConfiguration(max_size=5)), "use_cache": True, "single_connection_client": True, }, { - "cache": CacheToolsAdapter(TTLCache(128, 300)), + "cache": DefaultCache(CacheConfiguration(max_size=5)), "use_cache": True, "single_connection_client": False, }, @@ -64,67 +68,68 @@ def test_get_from_given_cache(self, r, r2): # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", [ { "use_cache": True, - "cache_eviction": EvictionPolicy.TTL, - "cache_size": 128, - "cache_ttl": 300, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LRU, - "cache_size": 128, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LFU, - "cache_size": 128, + "cache_config": CacheConfiguration(max_size=128), + "single_connection_client": True, }, { "use_cache": True, - "cache_eviction": EvictionPolicy.RANDOM, - "cache_size": 128, + "cache_config": CacheConfiguration(max_size=128), + "single_connection_client": False, }, ], - ids=["TTL", "LRU", "LFU", "RANDOM"], + ids=["single", "pool"], indirect=True, ) @pytest.mark.onlynoncluster - def test_get_from_custom_cache(self, request, r, r2): - expected_policy = EvictionPolicy(request.node.callspec.id) + def test_get_from_default_cache(self, r, r2): cache = r.get_cache() - assert expected_policy == cache.eviction_policy + assert isinstance(cache.get_eviction_policy(), LRUPolicy) + assert cache.get_max_size() == 128 # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": False, }, ], @@ -138,25 +143,28 @@ def test_health_check_invalidate_cache(self, r): # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r.set("foo", "barbar") # Wait for health check time.sleep(2) # Make sure that value was invalidated - assert cache.get(("GET", "foo")) is None + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": True, }, { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": False, }, ], @@ -171,55 +179,26 @@ def test_cache_clears_on_disconnect(self, r, cache): # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # Force disconnection r.connection_pool.get_connection("_").disconnect() # Make sure cache is empty - assert cache.currsize == 0 - - @pytest.mark.parametrize( - "r", - [ - {"use_cache": True, "cache_size": 3, "single_connection_client": True}, - {"use_cache": True, "cache_size": 3, "single_connection_client": False}, - ], - ids=["single", "pool"], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_cache_lru_eviction(self, r, cache): - cache = r.get_cache() - # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None + assert len(cache.get_collection()) == 0 @pytest.mark.parametrize( "r", [ { "use_cache": True, - "cache_eviction": EvictionPolicy.TTL, - "cache_ttl": 1, + "cache_config": CacheConfiguration(max_size=3), "single_connection_client": True, }, { "use_cache": True, - "cache_eviction": EvictionPolicy.TTL, - "cache_ttl": 1, + "cache_config": CacheConfiguration(max_size=3), "single_connection_client": False, }, ], @@ -227,40 +206,7 @@ def test_cache_lru_eviction(self, r, cache): indirect=True, ) @pytest.mark.onlynoncluster - def test_cache_ttl(self, r): - cache = r.get_cache() - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize( - "r", - [ - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LFU, - "cache_size": 3, - "single_connection_client": True, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LFU, - "cache_size": 3, - "single_connection_client": False, - }, - ], - ids=["single", "pool"], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_cache_lfu_eviction(self, r): + def test_cache_lru_eviction(self, r, cache): cache = r.get_cache() # add 3 keys to redis r.set("foo", "bar") @@ -270,29 +216,37 @@ def test_cache_lfu_eviction(self, r): assert r.get("foo") == b"bar" assert r.get("foo2") == b"bar2" assert r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" + # get the 3 keys from local cache + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo2",))).cache_value + == b"bar2" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo3",))).cache_value + == b"bar3" + ) # add 1 more key to redis (exceed the max size) r.set("foo4", "bar4") assert r.get("foo4") == b"bar4" - # test the eviction policy - assert cache.currsize == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None + # the first key is not in the local cache anymore + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None + assert len(cache.get_collection()) == 3 @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": True, }, { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": False, }, ], @@ -306,19 +260,19 @@ def test_cache_ignore_not_allowed_command(self, r): assert r.hset("foo", "bar", "baz") # get random field assert r.hrandfield("foo") == b"bar" - assert cache.get(("HRANDFIELD", "foo")) is None + assert cache.get(CacheKey(command="HRANDFIELD", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": True, }, { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": False, }, ], @@ -335,7 +289,10 @@ def test_cache_invalidate_all_related_responses(self, r): res = r.mget("foo", "bar") # Make sure that replies was cached assert res == [b"bar", b"foo"] - assert cache.get(("MGET", "foo", "bar")) == res + assert ( + cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))).cache_value + == res + ) # Make sure that objects are immutable. another_res = r.mget("foo", "bar") @@ -346,20 +303,23 @@ def test_cache_invalidate_all_related_responses(self, r): # all associated cached entries was removed assert r.set("foo", "baz") assert r.get("foo") == b"baz" - assert cache.get(("MGET", "foo", "bar")) is None - assert cache.get(("GET", "foo")) == b"baz" + assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"baz" + ) @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": True, }, { - "cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), "single_connection_client": False, }, ], @@ -378,14 +338,23 @@ def test_cache_flushed_on_server_flush(self, r): assert r.get("foo") == b"bar" assert r.get("bar") == b"foo" assert r.get("baz") == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - assert cache.get(("GET", "baz")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("bar",))).cache_value + == b"foo" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("baz",))).cache_value + == b"bar" + ) # Flush server and trying to access cached entry assert r.flushall() assert r.get("foo") is None - assert cache.currsize == 0 + assert len(cache.get_collection()) == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -394,202 +363,195 @@ def test_cache_flushed_on_server_flush(self, r): class TestClusterCache: @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], + [ + { + "use_cache": True, + "cache": DefaultCache(CacheConfiguration(max_size=128)), + } + ], indirect=True, ) - def test_get_from_cache(self, r, r2): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + @pytest.mark.onlycluster + def test_get_from_cache(self, r): + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) - r2.set("foo", "barbar") + r.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", [ { "use_cache": True, - "cache_eviction": EvictionPolicy.TTL, - "cache_size": 128, - "cache_ttl": 300, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LRU, - "cache_size": 128, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LFU, - "cache_size": 128, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.RANDOM, - "cache_size": 128, + "cache_config": CacheConfiguration(max_size=128), }, ], - ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True, ) - def test_get_from_custom_cache(self, request, r, r2): - expected_policy = EvictionPolicy[request.node.callspec.id] + def test_get_from_custom_cache(self, r, r2): cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() - assert expected_policy == cache.eviction_policy + assert isinstance(cache.get_eviction_policy(), LRUPolicy) + assert cache.get_max_size() == 128 # add key to redis assert r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + [ + { + "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), + }, + ], indirect=True, ) @pytest.mark.onlycluster def test_health_check_invalidate_cache(self, r, r2): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Wait for health check time.sleep(2) # Make sure that value was invalidated - assert cache.get(("GET", "foo")) is None + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(TTLCache(128, 300)), "use_cache": True}], + [ + { + "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), + }, + ], indirect=True, ) @pytest.mark.onlycluster def test_cache_clears_on_disconnect(self, r, r2): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() # add key to redis r.set("foo", "bar") # get key from redis and save in local cache assert r.get("foo") == b"bar" # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # Force disconnection r.nodes_manager.get_node_from_slot( - 10 + 12000 ).redis_connection.connection_pool.get_connection("_").disconnect() # Make sure cache is empty - assert cache.currsize == 0 + assert len(cache.get_collection()) == 0 @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(LRUCache(3)), "use_cache": True}], + [ + { + "use_cache": True, + "cache_config": CacheConfiguration(max_size=3), + }, + ], indirect=True, ) @pytest.mark.onlycluster def test_cache_lru_eviction(self, r): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") + r.set("foo{slot}", "bar") + r.set("foo2{slot}", "bar2") + r.set("foo3{slot}", "bar3") # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" + assert r.get("foo{slot}") == b"bar" + assert r.get("foo2{slot}") == b"bar2" + assert r.get("foo3{slot}") == b"bar3" # get the 3 keys from local cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) == b"bar2" - assert cache.get(("GET", "foo3")) == b"bar3" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo2{slot}",))).cache_value + == b"bar2" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo3{slot}",))).cache_value + == b"bar3" + ) # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # the first key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": CacheToolsAdapter(TTLCache(maxsize=128, ttl=1)), "use_cache": True}], - indirect=True, - ) - @pytest.mark.onlycluster - def test_cache_ttl(self, r): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" - # wait for the key to expire - time.sleep(1) - # the key is not in the local cache anymore - assert cache.get(("GET", "foo")) is None + r.set("foo4{slot}", "bar4") + assert r.get("foo4{slot}") == b"bar4" + # the first key is not in the local cache_data anymore + assert cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))) is None @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(LFUCache(3)), "use_cache": True}], - indirect=True, - ) - @pytest.mark.onlycluster - def test_cache_lfu_eviction(self, r): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() - # add 3 keys to redis - r.set("foo", "bar") - r.set("foo2", "bar2") - r.set("foo3", "bar3") - # get 3 keys from redis and save in local cache - assert r.get("foo") == b"bar" - assert r.get("foo2") == b"bar2" - assert r.get("foo3") == b"bar3" - # change the order of the keys in the cache - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo3")) == b"bar3" - # add 1 more key to redis (exceed the max size) - r.set("foo4", "bar4") - assert r.get("foo4") == b"bar4" - # test the eviction policy - assert cache.currsize == 3 - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "foo2")) is None - - @pytest.mark.parametrize( - "r", - [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], + [ + { + "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), + }, + ], indirect=True, ) @pytest.mark.onlycluster def test_cache_ignore_not_allowed_command(self, r): - cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() + cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() # add fields to hash assert r.hset("foo", "bar", "baz") # get random field assert r.hrandfield("foo") == b"bar" - assert cache.get(("HRANDFIELD", "foo")) is None + assert cache.get(CacheKey(command="HRANDFIELD", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], + [ + { + "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), + }, + ], indirect=True, ) @pytest.mark.onlycluster @@ -601,40 +563,64 @@ def test_cache_invalidate_all_related_responses(self, r, cache): # Make sure that replies was cached assert r.mget("foo{slot}", "bar{slot}") == [b"bar", b"foo"] - assert cache.get(("MGET", "foo{slot}", "bar{slot}")) == [b"bar", b"foo"] + assert cache.get( + CacheKey(command="MGET", redis_keys=("foo{slot}", "bar{slot}")), + ).cache_value == [b"bar", b"foo"] # Invalidate one of the keys and make sure # that all associated cached entries was removed assert r.set("foo{slot}", "baz") assert r.get("foo{slot}") == b"baz" - assert cache.get(("MGET", "foo{slot}", "bar{slot}")) is None - assert cache.get(("GET", "foo{slot}")) == b"baz" + assert ( + cache.get( + CacheKey(command="MGET", redis_keys=("foo{slot}", "bar{slot}")), + ) + is None + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value + == b"baz" + ) @pytest.mark.parametrize( "r", - [{"cache": CacheToolsAdapter(LRUCache(maxsize=128)), "use_cache": True}], + [ + { + "use_cache": True, + "cache_config": CacheConfiguration(max_size=128), + }, + ], indirect=True, ) @pytest.mark.onlycluster def test_cache_flushed_on_server_flush(self, r, cache): cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache() # Add keys - assert r.set("foo", "bar") - assert r.set("bar", "foo") - assert r.set("baz", "bar") + assert r.set("foo{slot}", "bar") + assert r.set("bar{slot}", "foo") + assert r.set("baz{slot}", "bar") # Make sure that replies was cached - assert r.get("foo") == b"bar" - assert r.get("bar") == b"foo" - assert r.get("baz") == b"bar" - assert cache.get(("GET", "foo")) == b"bar" - assert cache.get(("GET", "bar")) == b"foo" - assert cache.get(("GET", "baz")) == b"bar" + assert r.get("foo{slot}") == b"bar" + assert r.get("bar{slot}") == b"foo" + assert r.get("baz{slot}") == b"bar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value + == b"bar" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("bar{slot}",))).cache_value + == b"foo" + ) + assert ( + cache.get(CacheKey(command="GET", redis_keys=("baz{slot}",))).cache_value + == b"bar" + ) # Flush server and trying to access cached entry assert r.flushall() - assert r.get("foo") is None - assert cache.currsize == 0 + assert r.get("foo{slot}") is None + assert len(cache.get_collection()) == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -645,7 +631,7 @@ class TestSentinelCache: "sentinel_setup", [ { - "cache": CacheToolsAdapter(LRUCache(maxsize=128)), + "cache": DefaultCache(CacheConfiguration(max_size=128)), "use_cache": True, "force_master_ip": "localhost", } @@ -656,68 +642,61 @@ class TestSentinelCache: def test_get_from_cache(self, master): cache = master.get_cache() master.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert master.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) master.set("foo", "barbar") # get key from redis assert master.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", [ { "use_cache": True, - "cache_eviction": EvictionPolicy.TTL, - "cache_size": 128, - "cache_ttl": 300, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LRU, - "cache_size": 128, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LFU, - "cache_size": 128, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.RANDOM, - "cache_size": 128, + "cache_config": CacheConfiguration(max_size=128), }, ], - ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True, ) - def test_get_from_custom_cache(self, request, r, r2): - expected_policy = EvictionPolicy[request.node.callspec.id] + def test_get_from_default_cache(self, r, r2): cache = r.get_cache() - assert expected_policy == cache.eviction_policy + assert isinstance(cache.get_eviction_policy(), LRUPolicy) # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") - # Retrieves a new value from server and cache it + # Retrieves a new value from server and cache_data it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "sentinel_setup", [ { - "cache": CacheToolsAdapter(LRUCache(maxsize=128)), + "cache_config": CacheConfiguration(max_size=128), "use_cache": True, "force_master_ip": "localhost", } @@ -729,22 +708,25 @@ def test_health_check_invalidate_cache(self, master, cache): cache = master.get_cache() # add key to redis master.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert master.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) master.set("foo", "barbar") # Wait for health check time.sleep(2) # Make sure that value was invalidated - assert cache.get(("GET", "foo")) is None + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @pytest.mark.parametrize( "sentinel_setup", [ { - "cache": CacheToolsAdapter(LRUCache(maxsize=128)), + "cache_config": CacheConfiguration(max_size=128), "use_cache": True, "force_master_ip": "localhost", } @@ -756,14 +738,17 @@ def test_cache_clears_on_disconnect(self, master, cache): cache = master.get_cache() # add key to redis master.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert master.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # Force disconnection master.connection_pool.get_connection("_").disconnect() - # Make sure cache is empty - assert cache.currsize == 0 + # Make sure cache_data is empty + assert len(cache.get_collection()) == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -774,7 +759,7 @@ class TestSSLCache: "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), + "cache": DefaultCache(CacheConfiguration(max_size=128)), "use_cache": True, "ssl": True, } @@ -786,72 +771,62 @@ def test_get_from_cache(self, r, r2, cache): cache = r.get_cache() # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) assert r2.set("foo", "barbar") - # Retrieves a new value from server and cache it + # Retrieves a new value from server and cache_data it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", [ { "use_cache": True, - "cache_eviction": EvictionPolicy.TTL, - "cache_size": 128, - "cache_ttl": 300, - "ssl": True, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LRU, - "cache_size": 128, - "ssl": True, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.LFU, - "cache_size": 128, - "ssl": True, - }, - { - "use_cache": True, - "cache_eviction": EvictionPolicy.RANDOM, - "cache_size": 128, + "cache_config": CacheConfiguration(max_size=128), "ssl": True, }, ], - ids=["TTL", "LRU", "LFU", "RANDOM"], indirect=True, ) - def test_get_from_custom_cache(self, request, r, r2): - expected_policy = EvictionPolicy[request.node.callspec.id] + def test_get_from_custom_cache(self, r, r2): cache = r.get_cache() - assert expected_policy == cache.eviction_policy + assert isinstance(cache.get_eviction_policy(), LRUPolicy) # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") - # Retrieves a new value from server and cache it + # Retrieves a new value from server and cache_data it assert r.get("foo") == b"barbar" # Make sure that new value was cached - assert cache.get(("GET", "foo")) == b"barbar" + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"barbar" + ) @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), + "cache_config": CacheConfiguration(max_size=128), "use_cache": True, "ssl": True, } @@ -863,22 +838,25 @@ def test_health_check_invalidate_cache(self, r, r2): cache = r.get_cache() # add key to redis r.set("foo", "bar") - # get key from redis and save in local cache + # get key from redis and save in local cache_data assert r.get("foo") == b"bar" - # get key from local cache - assert cache.get(("GET", "foo")) == b"bar" + # get key from local cache_data + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"bar" + ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Wait for health check time.sleep(2) # Make sure that value was invalidated - assert cache.get(("GET", "foo")) is None + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @pytest.mark.parametrize( "r", [ { - "cache": CacheToolsAdapter(TTLCache(128, 300)), + "cache_config": CacheConfiguration(max_size=128), "use_cache": True, "ssl": True, } @@ -894,23 +872,342 @@ def test_cache_invalidate_all_related_responses(self, r): # Make sure that replies was cached assert r.mget("foo", "bar") == [b"bar", b"foo"] - assert cache.get(("MGET", "foo", "bar")) == [b"bar", b"foo"] + assert cache.get( + CacheKey(command="MGET", redis_keys=("foo", "bar")) + ).cache_value == [b"bar", b"foo"] # Invalidate one of the keys and make sure # that all associated cached entries was removed assert r.set("foo", "baz") assert r.get("foo") == b"baz" - assert cache.get(("MGET", "foo", "bar")) is None - assert cache.get(("GET", "foo")) == b"baz" + assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None + assert ( + cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value + == b"baz" + ) + + +class TestUnitDefaultCache: + def test_get_eviction_policy(self): + cache = DefaultCache(CacheConfiguration(max_size=5)) + assert isinstance(cache.get_eviction_policy(), LRUPolicy) + + def test_get_max_size(self): + cache = DefaultCache(CacheConfiguration(max_size=5)) + assert cache.get_max_size() == 5 + + @pytest.mark.parametrize( + "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True + ) + def test_set_non_existing_cache_key(self, cache_key): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + assert cache.set( + CacheEntry( + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + ) + ) + assert cache.get(cache_key).cache_value == b"val" + + @pytest.mark.parametrize( + "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True + ) + def test_set_updates_existing_cache_key(self, cache_key): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + assert cache.set( + CacheEntry( + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + ) + ) + assert cache.get(cache_key).cache_value == b"val" + + cache.set( + CacheEntry( + cache_key=cache_key, + cache_value=b"new_val", + status=CacheEntryStatus.VALID, + ) + ) + assert cache.get(cache_key).cache_value == b"new_val" + + @pytest.mark.parametrize( + "cache_key", [{"command": "HRANDFIELD", "redis_keys": ("bar",)}], indirect=True + ) + def test_set_does_not_store_not_allowed_key(self, cache_key): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + assert not cache.set( + CacheEntry( + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + ) + ) + + def test_set_evict_lru_cache_key_on_reaching_max_size(self): + cache = DefaultCache(CacheConfiguration(max_size=3)) + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + ) + ) + + # Accessing key in the order that it makes 2nd key LRU + assert cache.get(cache_key1).cache_value == b"bar" + assert cache.get(cache_key2).cache_value == b"bar1" + assert cache.get(cache_key3).cache_value == b"bar2" + assert cache.get(cache_key1).cache_value == b"bar" + + cache_key4 = CacheKey(command="GET", redis_keys=("foo3",)) + assert cache.set( + CacheEntry( + cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID + ) + ) + + # Make sure that new key was added and 2nd is evicted + assert cache.get(cache_key4).cache_value == b"bar3" + assert cache.get(cache_key2) is None + + @pytest.mark.parametrize( + "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True + ) + def test_get_return_correct_value(self, cache_key): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + assert cache.set( + CacheEntry( + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + ) + ) + assert cache.get(cache_key).cache_value == b"val" + + wrong_key = CacheKey(command="HGET", redis_keys=("foo",)) + assert cache.get(wrong_key) is None + + result = cache.get(cache_key) + + # Make sure that result is immutable. + assert result != cache.get(cache_key) + + def test_delete_by_cache_keys_removes_associated_entries(self): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) + cache_key4 = CacheKey(command="GET", redis_keys=("foo3",)) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + ) + ) + + assert cache.delete_by_cache_keys([cache_key1, cache_key2, cache_key4]) == [ + True, + True, + False, + ] + assert len(cache.get_collection()) == 1 + assert cache.get(cache_key3).cache_value == b"bar2" + + def test_delete_by_redis_keys_removes_associated_entries(self): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="MGET", redis_keys=("foo", "foo3")) + cache_key4 = CacheKey(command="MGET", redis_keys=("foo2", "foo3")) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID + ) + ) + + assert cache.delete_by_redis_keys([b"foo", b"foo1"]) == [True, True, True] + assert len(cache.get_collection()) == 1 + assert cache.get(cache_key4).cache_value == b"bar3" + + def test_flush(self): + cache = DefaultCache(CacheConfiguration(max_size=5)) + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) + cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) + + # Set 3 different keys + assert cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + ) + ) + + assert cache.flush() == 3 + assert len(cache.get_collection()) == 0 + + +class TestUnitLRUPolicy: + def test_type(self): + policy = LRUPolicy() + assert policy.type == EvictionPolicyType.time_based + + def test_evict_next(self): + cache = DefaultCache( + CacheConfiguration(max_size=5, eviction_policy=EvictionPolicy.LRU) + ) + policy = cache.get_eviction_policy() + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) + + assert cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + ) + ) + + assert policy.evict_next() == cache_key1 + assert cache.get(cache_key1) is None + + def test_evict_many(self): + cache = DefaultCache( + CacheConfiguration(max_size=5, eviction_policy=EvictionPolicy.LRU) + ) + policy = cache.get_eviction_policy() + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) + cache_key3 = CacheKey(command="GET", redis_keys=("baz",)) + + assert cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + ) + ) + assert cache.set( + CacheEntry( + cache_key=cache_key3, cache_value=b"baz", status=CacheEntryStatus.VALID + ) + ) + + assert policy.evict_many(2) == [cache_key1, cache_key2] + assert cache.get(cache_key1) is None + assert cache.get(cache_key2) is None + + with pytest.raises(ValueError, match="Evictions count is above cache size"): + policy.evict_many(99) + + def test_touch(self): + cache = DefaultCache( + CacheConfiguration(max_size=5, eviction_policy=EvictionPolicy.LRU) + ) + policy = cache.get_eviction_policy() + + cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) + cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) + + cache.set( + CacheEntry( + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + ) + ) + + assert cache.get_collection().popitem(last=True)[0] == cache_key2 + cache.set( + CacheEntry( + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + ) + ) + + policy.touch(cache_key1) + assert cache.get_collection().popitem(last=True)[0] == cache_key1 + + def test_throws_error_on_invalid_cache(self): + policy = LRUPolicy() + + with pytest.raises( + ValueError, match="Eviction policy should be associated with valid cache." + ): + policy.evict_next() + + policy.cache = "wrong_type" + + with pytest.raises( + ValueError, match="Eviction policy should be associated with valid cache." + ): + policy.evict_next() class TestUnitCacheConfiguration: - TTL = 20 MAX_SIZE = 100 - EVICTION_POLICY = EvictionPolicy.TTL - - def test_get_ttl(self, cache_conf: CacheConfiguration): - assert self.TTL == cache_conf.get_ttl() + EVICTION_POLICY = EvictionPolicy.LRU def test_get_max_size(self, cache_conf: CacheConfiguration): assert self.MAX_SIZE == cache_conf.get_max_size() diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 4ad88e7c08..6219e3bb85 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2692,7 +2692,7 @@ def test_init_slots_cache_slots_collision(self, request): def create_mocked_redis_node(host, port, **kwargs): """ - Helper function to return custom slots cache data from + Helper function to return custom slots cache_data data from different redis nodes """ if port == 7000: @@ -2733,7 +2733,7 @@ def execute_command(*args, **kwargs): node_2 = ClusterNode("127.0.0.1", 7001) RedisCluster(startup_nodes=[node_1, node_2]) assert str(ex.value).startswith( - "startup_nodes could not agree on a valid slots cache" + "startup_nodes could not agree on a valid slots cache_data" ), str(ex.value) def test_cluster_one_instance(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index e34e431cb0..8e98d4b241 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,11 +5,20 @@ import pytest import redis -from cachetools import LRUCache, TTLCache from redis import ConnectionPool, Redis from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff -from redis.cache import CacheInterface, CacheToolsAdapter, EvictionPolicy +from redis.cache import ( + CacheConfiguration, + CacheEntry, + CacheEntryStatus, + CacheInterface, + CacheKey, + DefaultCache, + EvictionPolicy, + EvictionPolicyInterface, + LRUPolicy, +) from redis.connection import ( CacheProxyConnection, Connection, @@ -388,9 +397,7 @@ def test_creates_cache_with_custom_cache_factory( connection_pool = ConnectionPool( protocol=3, use_cache=True, - cache_size=100, - cache_ttl=20, - cache_eviction=EvictionPolicy.TTL, + cache_config=CacheConfiguration(max_size=5), cache_factory=mock_cache_factory, ) @@ -399,16 +406,12 @@ def test_creates_cache_with_custom_cache_factory( def test_creates_cache_with_given_configuration(self, mock_cache): connection_pool = ConnectionPool( - protocol=3, - use_cache=True, - cache_size=100, - cache_ttl=20, - cache_eviction=EvictionPolicy.TTL, + protocol=3, use_cache=True, cache_config=CacheConfiguration(max_size=100) ) assert isinstance(connection_pool.cache, CacheInterface) - assert connection_pool.cache.maxsize == 100 - assert connection_pool.cache.eviction_policy == EvictionPolicy.TTL + assert connection_pool.cache.get_max_size() == 100 + assert isinstance(connection_pool.cache.get_eviction_policy(), LRUPolicy) connection_pool.disconnect() def test_make_connection_proxy_connection_on_given_cache(self): @@ -420,18 +423,22 @@ def test_make_connection_proxy_connection_on_given_cache(self): class TestUnitCacheProxyConnection: def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): - cache = LRUCache(100) - cache["key"] = "value" - assert cache["key"] == "value" + cache = DefaultCache(10, eviction_policy=LRUPolicy()) + cache_key = CacheKey(command="GET", redis_keys=("foo",)) + + cache.set( + CacheEntry( + cache_key=cache_key, cache_value=b"bar", status=CacheEntryStatus.VALID + ) + ) + assert cache.get(cache_key).cache_value == b"bar" mock_connection.disconnect.return_value = None mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" - proxy_connection = CacheProxyConnection( - mock_connection, CacheToolsAdapter(cache), cache_conf - ) + proxy_connection = CacheProxyConnection(mock_connection, cache, cache_conf) proxy_connection.disconnect() - assert cache.currsize == 0 + assert len(cache.get_collection()) == 0 From d78c5d3881dad6ffed1c046cb7a43892646f8ad3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 29 Aug 2024 14:22:00 +0300 Subject: [PATCH 47/78] Updated test cases --- redis/cache.py | 4 ++-- tests/test_connection.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 8903569f99..9b85c5776c 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -17,11 +17,11 @@ class EvictionPolicyType(Enum): class CacheKey: - def __init__(self, command: str, redis_keys: tuple[str, ...]): + def __init__(self, command: str, redis_keys: tuple): self.command = command self.redis_keys = redis_keys - def get_redis_keys(self) -> tuple[str, ...]: + def get_redis_keys(self) -> tuple: return self.redis_keys def __hash__(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index 8e98d4b241..0b1f6fb5ad 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -381,7 +381,7 @@ def test_throws_error_on_cache_enable_in_resp2(self): def test_throws_error_on_incorrect_cache_implementation(self): with pytest.raises(ValueError, match="Cache must implement CacheInterface"): - ConnectionPool(protocol=3, use_cache=True, cache=TTLCache(100, 20)) + ConnectionPool(protocol=3, use_cache=True, cache='wrong') def test_returns_custom_cache_implementation(self, mock_cache): connection_pool = ConnectionPool(protocol=3, use_cache=True, cache=mock_cache) @@ -423,7 +423,7 @@ def test_make_connection_proxy_connection_on_given_cache(self): class TestUnitCacheProxyConnection: def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): - cache = DefaultCache(10, eviction_policy=LRUPolicy()) + cache = DefaultCache(CacheConfiguration(max_size=10)) cache_key = CacheKey(command="GET", redis_keys=("foo",)) cache.set( @@ -438,7 +438,7 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.host = "mock" mock_connection.port = "mock" - proxy_connection = CacheProxyConnection(mock_connection, cache, cache_conf) + proxy_connection = CacheProxyConnection(mock_connection, cache) proxy_connection.disconnect() assert len(cache.get_collection()) == 0 From 9cac761cddf5627d0b685cd08fefe0e81a9e7f59 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 29 Aug 2024 14:24:51 +0300 Subject: [PATCH 48/78] Updated typings --- redis/cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 9b85c5776c..97f6315582 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -70,7 +70,7 @@ def touch(self, cache_key: CacheKey) -> None: class CacheInterface(ABC): @abstractmethod - def get_collection(self) -> OrderedDict[CacheKey, CacheEntry]: + def get_collection(self) -> OrderedDict: pass @abstractmethod @@ -138,7 +138,7 @@ def __init__( self._eviction_policy = self._cache_config.get_eviction_policy().value() self._eviction_policy.cache = self - def get_collection(self) -> OrderedDict[CacheKey, CacheEntry]: + def get_collection(self) -> OrderedDict: return self._cache def get_eviction_policy(self) -> EvictionPolicyInterface: From 43af6eefb02438c3383cc572ca946ce1074c2beb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 29 Aug 2024 14:33:44 +0300 Subject: [PATCH 49/78] Updated types --- redis/cache.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 97f6315582..7f99aa352d 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum -from typing import Any, Collection, Hashable, List, Optional +from typing import Any, Collection, Hashable, List, Optional, Union class CacheEntryStatus(Enum): @@ -82,7 +82,7 @@ def get_max_size(self) -> int: pass @abstractmethod - def get(self, key: CacheKey) -> CacheEntry | None: + def get(self, key: CacheKey) -> Union[CacheEntry, None]: pass @abstractmethod @@ -159,7 +159,7 @@ def set(self, entry: CacheEntry) -> bool: return True - def get(self, key: CacheKey) -> CacheEntry | None: + def get(self, key: CacheKey) -> Union[CacheEntry, None]: entry = self._cache.get(key, None) if entry is None: From fa1a43176fcf55a8ff6c3eebca43d89a64dd1e8f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 29 Aug 2024 17:40:16 +0300 Subject: [PATCH 50/78] Revert changes --- tests/test_cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 6219e3bb85..c4b3188050 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -2733,7 +2733,7 @@ def execute_command(*args, **kwargs): node_2 = ClusterNode("127.0.0.1", 7001) RedisCluster(startup_nodes=[node_1, node_2]) assert str(ex.value).startswith( - "startup_nodes could not agree on a valid slots cache_data" + "startup_nodes could not agree on a valid slots cache" ), str(ex.value) def test_cluster_one_instance(self): From 0ffc298ba06f222f59f75078d9cc502f625b1f36 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 30 Aug 2024 15:07:20 +0300 Subject: [PATCH 51/78] Removed use_cache, make health_check configurable, removed retry logic around can_read() --- redis/cache.py | 64 +++++++++----- redis/client.py | 12 +-- redis/cluster.py | 15 +--- redis/connection.py | 11 ++- redis/retry.py | 32 ------- tests/conftest.py | 8 +- tests/test_cache.py | 177 ++++++++++++++++++--------------------- tests/test_connection.py | 86 ++++++++++++++++--- 8 files changed, 215 insertions(+), 190 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 7f99aa352d..6320682e97 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -39,6 +39,12 @@ def __init__( self.cache_value = cache_value self.status = status + def __hash__(self): + return hash((self.cache_key, self.cache_value, self.status)) + + def __eq__(self, other): + return hash(self) == hash(other) + class EvictionPolicyInterface(ABC): @property @@ -68,63 +74,71 @@ def touch(self, cache_key: CacheKey) -> None: pass -class CacheInterface(ABC): +class CacheConfigurationInterface(ABC): @abstractmethod - def get_collection(self) -> OrderedDict: + def get_cache_class(self): pass @abstractmethod - def get_eviction_policy(self) -> EvictionPolicyInterface: + def get_max_size(self) -> int: pass @abstractmethod - def get_max_size(self) -> int: + def get_eviction_policy(self): pass @abstractmethod - def get(self, key: CacheKey) -> Union[CacheEntry, None]: + def get_health_check_interval(self) -> float: pass @abstractmethod - def set(self, entry: CacheEntry) -> bool: + def is_exceeds_max_size(self, count: int) -> bool: pass @abstractmethod - def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: + def is_allowed_to_cache(self, command: str) -> bool: pass + +class CacheInterface(ABC): @abstractmethod - def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: + def get_collection(self) -> OrderedDict: pass @abstractmethod - def flush(self) -> int: + def get_config(self) -> CacheConfigurationInterface: pass @abstractmethod - def is_cachable(self, key: CacheKey) -> bool: + def get_eviction_policy(self) -> EvictionPolicyInterface: pass + @abstractmethod + def get_size(self) -> int: + pass -class CacheConfigurationInterface(ABC): @abstractmethod - def get_cache_class(self): + def get(self, key: CacheKey) -> Union[CacheEntry, None]: pass @abstractmethod - def get_max_size(self) -> int: + def set(self, entry: CacheEntry) -> bool: pass @abstractmethod - def get_eviction_policy(self): + def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: pass @abstractmethod - def is_exceeds_max_size(self, count: int) -> bool: + def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: pass @abstractmethod - def is_allowed_to_cache(self, command: str) -> bool: + def flush(self) -> int: + pass + + @abstractmethod + def is_cachable(self, key: CacheKey) -> bool: pass @@ -141,11 +155,14 @@ def __init__( def get_collection(self) -> OrderedDict: return self._cache + def get_config(self) -> CacheConfigurationInterface: + return self._cache_config + def get_eviction_policy(self) -> EvictionPolicyInterface: return self._eviction_policy - def get_max_size(self) -> int: - return self._cache_config.get_max_size() + def get_size(self) -> int: + return len(self._cache) def set(self, entry: CacheEntry) -> bool: if not self.is_cachable(entry.cache_key): @@ -256,7 +273,7 @@ class EvictionPolicy(Enum): LRU = LRUPolicy -class CacheConfiguration(CacheConfigurationInterface): +class CacheConfig(CacheConfigurationInterface): DEFAULT_CACHE_CLASS = DefaultCache DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU DEFAULT_MAX_SIZE = 10000 @@ -343,10 +360,12 @@ def __init__( max_size: int = DEFAULT_MAX_SIZE, cache_class: Any = DEFAULT_CACHE_CLASS, eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, + health_check_interval: float = 2.0, ): self._cache_class = cache_class self._max_size = max_size self._eviction_policy = eviction_policy + self._health_check_interval = health_check_interval def get_cache_class(self): return self._cache_class @@ -357,6 +376,9 @@ def get_max_size(self) -> int: def get_eviction_policy(self) -> EvictionPolicy: return self._eviction_policy + def get_health_check_interval(self) -> float: + return self._health_check_interval + def is_exceeds_max_size(self, count: int) -> bool: return count > self._max_size @@ -371,11 +393,11 @@ def get_cache(self) -> CacheInterface: class CacheFactory(CacheFactoryInterface): - def __init__(self, cache_config: Optional[CacheConfiguration] = None): + def __init__(self, cache_config: Optional[CacheConfig] = None): self._config = cache_config if self._config is None: - self._config = CacheConfiguration() + self._config = CacheConfig() def get_cache(self) -> CacheInterface: cache_class = self._config.get_cache_class() diff --git a/redis/client.py b/redis/client.py index 220fbf238b..ec2edfcb35 100755 --- a/redis/client.py +++ b/redis/client.py @@ -13,7 +13,7 @@ _RedisCallbacksRESP3, bool_ok, ) -from redis.cache import CacheConfiguration, CacheInterface +from redis.cache import CacheConfig, CacheInterface from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -142,12 +142,10 @@ class initializer. In the case of conflicting arguments, querystring """ single_connection_client = kwargs.pop("single_connection_client", False) - use_cache = kwargs.pop("use_cache", False) connection_pool = ConnectionPool.from_url(url, **kwargs) client = cls( connection_pool=connection_pool, single_connection_client=single_connection_client, - use_cache=use_cache, ) client.auto_close_connection_pool = True return client @@ -213,9 +211,8 @@ def __init__( redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - use_cache: bool = False, cache: Optional[CacheInterface] = None, - cache_config: Optional[CacheConfiguration] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: """ Initialize a new Redis client. @@ -308,10 +305,9 @@ def __init__( "ssl_ciphers": ssl_ciphers, } ) - if use_cache and protocol in [3, "3"]: + if (cache_config or cache) and protocol in [3, "3"]: kwargs.update( { - "use_cache": use_cache, "cache": cache, "cache_config": cache_config, } @@ -323,7 +319,7 @@ def __init__( self.connection_pool = connection_pool - if use_cache and self.connection_pool.get_protocol() not in [3, "3"]: + if (cache_config or cache) and self.connection_pool.get_protocol() not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") self.connection = None diff --git a/redis/cluster.py b/redis/cluster.py index 1635d0f8b7..a3142028fc 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,7 +9,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.cache import CacheConfiguration, CacheInterface +from redis.cache import CacheConfig, CacheInterface from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -168,7 +168,6 @@ def parse_cluster_myshardid(resp, **options): "ssl_password", "unix_socket_path", "username", - "use_cache", "cache", "cache_config", ) @@ -504,9 +503,8 @@ def __init__( dynamic_startup_nodes: bool = True, url: Optional[str] = None, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - use_cache: bool = False, cache: Optional[CacheInterface] = None, - cache_config: Optional[CacheConfiguration] = None, + cache_config: Optional[CacheConfig] = None, **kwargs, ): """ @@ -631,7 +629,7 @@ def __init__( kwargs.get("decode_responses", False), ) protocol = kwargs.get("protocol", None) - if use_cache and protocol not in [3, "3"]: + if (cache_config or cache) and protocol not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") self.cluster_error_retry_attempts = cluster_error_retry_attempts @@ -646,7 +644,6 @@ def __init__( require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, - use_cache=use_cache, cache=cache, cache_config=cache_config, **kwargs, @@ -1328,9 +1325,8 @@ def __init__( dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - use_cache: bool = False, cache: Optional[CacheInterface] = None, - cache_config: Optional[CacheConfiguration] = None, + cache_config: Optional[CacheConfig] = None, **kwargs, ): self.nodes_cache = {} @@ -1343,7 +1339,6 @@ def __init__( self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap - self.use_cache = use_cache self.cache = cache self.cache_config = cache_config self._moved_exception = None @@ -1489,7 +1484,6 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - kwargs.update({"use_cache": self.use_cache}) kwargs.update({"cache": self.cache}) kwargs.update({"cache_config": self.cache_config}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) @@ -1497,7 +1491,6 @@ def create_redis_node(self, host, port, **kwargs): r = Redis( host=host, port=port, - use_cache=self.use_cache, cache=self.cache, cache_config=self.cache_config, **kwargs, diff --git a/redis/connection.py b/redis/connection.py index 95e7cb074d..1d13a25c46 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -13,7 +13,7 @@ from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( - CacheConfiguration, + CacheConfig, CacheEntry, CacheEntryStatus, CacheFactory, @@ -879,7 +879,7 @@ def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) def _process_pending_invalidations(self): - while self.retry.call_with_retry_on_false(lambda: self.can_read()): + while self.can_read(): self._conn.read_response(push_request=True) def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): @@ -1251,13 +1251,12 @@ def __init__( self.connection_kwargs = connection_kwargs self.max_connections = max_connections self.cache = None - self._cache_conf = None self._cache_factory = cache_factory self._scheduler = None self._hc_cancel_event = None self._hc_thread = None - if connection_kwargs.get("use_cache"): + if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): if connection_kwargs.get("protocol") not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") @@ -1278,7 +1277,6 @@ def __init__( self._scheduler = Scheduler() - connection_kwargs.pop("use_cache", None) connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) @@ -1494,8 +1492,9 @@ def run_scheduled_healthcheck(self) -> None: # Run scheduled healthcheck to avoid stale invalidations in idle connections. if self.cache is not None and self._scheduler is not None: self._hc_cancel_event = threading.Event() + hc_interval = self.cache.get_config().get_health_check_interval() self._hc_thread = self._scheduler.run_with_interval( - self._perform_health_check, 2, self._hc_cancel_event + self._perform_health_check, hc_interval, self._hc_cancel_event ) def _perform_health_check(self, done: threading.Event) -> None: diff --git a/redis/retry.py b/redis/retry.py index 218159f861..1b0fe113f8 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -78,35 +78,3 @@ def call_with_retry( backoff = self._backoff.compute(failures) if backoff > 0: sleep(backoff) - - def call_with_retry_on_false( - self, - do: Callable[[], T], - on_false: Optional[Callable[[], T]] = None, - max_retries: Optional[int] = 3, - timeout: Optional[float] = 0, - exponent: Optional[int] = 2, - ) -> bool: - """ - Execute an operation that returns boolean value with retry - logic in case if false value been returned. - `do`: the operation to call. Expects no argument. - `on_false`: Callback to be executed on retry fail. - """ - res = do() - - if res: - return res - - if on_false is not None: - on_false() - - if max_retries > 0: - if timeout > 0: - time.sleep(timeout) - - return self.call_with_retry_on_false( - do, on_false, max_retries - 1, timeout * exponent, exponent - ) - - return False diff --git a/tests/conftest.py b/tests/conftest.py index 3b73bedf95..0755fd390e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ from redis import Sentinel from redis.backoff import NoBackoff from redis.cache import ( - CacheConfiguration, + CacheConfig, CacheFactoryInterface, CacheInterface, CacheKey, @@ -439,7 +439,6 @@ def sentinel_setup(request): for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(",")) ] kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {} - use_cache = request.param.get("use_cache", False) cache = request.param.get("cache", None) cache_config = request.param.get("cache_config", None) force_master_ip = request.param.get("force_master_ip", None) @@ -447,7 +446,6 @@ def sentinel_setup(request): sentinel_endpoints, force_master_ip=force_master_ip, socket_timeout=0.1, - use_cache=use_cache, cache=cache, cache_config=cache_config, protocol=3, @@ -543,8 +541,8 @@ def master_host(request): @pytest.fixture() -def cache_conf() -> CacheConfiguration: - return CacheConfiguration(max_size=100, eviction_policy=EvictionPolicy.LRU) +def cache_conf() -> CacheConfig: + return CacheConfig(max_size=100, eviction_policy=EvictionPolicy.LRU) @pytest.fixture() diff --git a/tests/test_cache.py b/tests/test_cache.py index e4298264af..d16009e270 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -3,7 +3,7 @@ import pytest import redis from redis.cache import ( - CacheConfiguration, + CacheConfig, CacheEntry, CacheEntryStatus, CacheKey, @@ -18,7 +18,6 @@ @pytest.fixture() def r(request): - use_cache = request.param.get("use_cache", False) cache = request.param.get("cache") cache_config = request.param.get("cache_config") kwargs = request.param.get("kwargs", {}) @@ -31,7 +30,6 @@ def r(request): protocol=protocol, ssl=ssl, single_connection_client=single_connection_client, - use_cache=use_cache, cache=cache, cache_config=cache_config, **kwargs, @@ -41,19 +39,17 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +#@skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", [ { - "cache": DefaultCache(CacheConfiguration(max_size=5)), - "use_cache": True, + "cache": DefaultCache(CacheConfig(max_size=5)), "single_connection_client": True, }, { - "cache": DefaultCache(CacheConfiguration(max_size=5)), - "use_cache": True, + "cache": DefaultCache(CacheConfig(max_size=5)), "single_connection_client": False, }, ], @@ -86,13 +82,11 @@ def test_get_from_given_cache(self, r, r2): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": True, }, { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, ], @@ -103,7 +97,7 @@ def test_get_from_given_cache(self, r, r2): def test_get_from_default_cache(self, r, r2): cache = r.get_cache() assert isinstance(cache.get_eviction_policy(), LRUPolicy) - assert cache.get_max_size() == 128 + assert cache.get_config().get_max_size() == 128 # add key to redis r.set("foo", "bar") @@ -128,8 +122,7 @@ def test_get_from_default_cache(self, r, r2): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128, health_check_interval=1.0), "single_connection_client": False, }, ], @@ -150,7 +143,7 @@ def test_health_check_invalidate_cache(self, r): # change key in redis (cause invalidation) r.set("foo", "barbar") # Wait for health check - time.sleep(2) + time.sleep(1.0) # Make sure that value was invalidated assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @@ -158,13 +151,11 @@ def test_health_check_invalidate_cache(self, r): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": True, }, { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, ], @@ -186,19 +177,17 @@ def test_cache_clears_on_disconnect(self, r, cache): # Force disconnection r.connection_pool.get_connection("_").disconnect() # Make sure cache is empty - assert len(cache.get_collection()) == 0 + assert cache.get_size() == 0 @pytest.mark.parametrize( "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=3), + "cache_config": CacheConfig(max_size=3), "single_connection_client": True, }, { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=3), + "cache_config": CacheConfig(max_size=3), "single_connection_client": False, }, ], @@ -234,19 +223,17 @@ def test_cache_lru_eviction(self, r, cache): assert r.get("foo4") == b"bar4" # the first key is not in the local cache anymore assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None - assert len(cache.get_collection()) == 3 + assert cache.get_size() == 3 @pytest.mark.parametrize( "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": True, }, { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, ], @@ -266,13 +253,11 @@ def test_cache_ignore_not_allowed_command(self, r): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": True, }, { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, ], @@ -313,13 +298,11 @@ def test_cache_invalidate_all_related_responses(self, r): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": True, }, { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, ], @@ -354,7 +337,7 @@ def test_cache_flushed_on_server_flush(self, r): # Flush server and trying to access cached entry assert r.flushall() assert r.get("foo") is None - assert len(cache.get_collection()) == 0 + assert cache.get_size() == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -365,8 +348,7 @@ class TestClusterCache: "r", [ { - "use_cache": True, - "cache": DefaultCache(CacheConfiguration(max_size=128)), + "cache": DefaultCache(CacheConfig(max_size=128)), } ], indirect=True, @@ -397,8 +379,7 @@ def test_get_from_cache(self, r): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -406,7 +387,7 @@ def test_get_from_cache(self, r): def test_get_from_custom_cache(self, r, r2): cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() assert isinstance(cache.get_eviction_policy(), LRUPolicy) - assert cache.get_max_size() == 128 + assert cache.get_config().get_max_size() == 128 # add key to redis assert r.set("foo", "bar") @@ -431,8 +412,7 @@ def test_get_from_custom_cache(self, r, r2): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -460,8 +440,7 @@ def test_health_check_invalidate_cache(self, r, r2): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -483,14 +462,13 @@ def test_cache_clears_on_disconnect(self, r, r2): 12000 ).redis_connection.connection_pool.get_connection("_").disconnect() # Make sure cache is empty - assert len(cache.get_collection()) == 0 + assert cache.get_size() == 0 @pytest.mark.parametrize( "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=3), + "cache_config": CacheConfig(max_size=3), }, ], indirect=True, @@ -529,8 +507,7 @@ def test_cache_lru_eviction(self, r): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -548,8 +525,7 @@ def test_cache_ignore_not_allowed_command(self, r): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -586,8 +562,7 @@ def test_cache_invalidate_all_related_responses(self, r, cache): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -620,7 +595,7 @@ def test_cache_flushed_on_server_flush(self, r, cache): # Flush server and trying to access cached entry assert r.flushall() assert r.get("foo{slot}") is None - assert len(cache.get_collection()) == 0 + assert cache.get_size() == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -631,8 +606,7 @@ class TestSentinelCache: "sentinel_setup", [ { - "cache": DefaultCache(CacheConfiguration(max_size=128)), - "use_cache": True, + "cache": DefaultCache(CacheConfig(max_size=128)), "force_master_ip": "localhost", } ], @@ -663,8 +637,7 @@ def test_get_from_cache(self, master): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), }, ], indirect=True, @@ -696,8 +669,7 @@ def test_get_from_default_cache(self, r, r2): "sentinel_setup", [ { - "cache_config": CacheConfiguration(max_size=128), - "use_cache": True, + "cache_config": CacheConfig(max_size=128, health_check_interval=1.0), "force_master_ip": "localhost", } ], @@ -718,7 +690,7 @@ def test_health_check_invalidate_cache(self, master, cache): # change key in redis (cause invalidation) master.set("foo", "barbar") # Wait for health check - time.sleep(2) + time.sleep(1.0) # Make sure that value was invalidated assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @@ -726,8 +698,7 @@ def test_health_check_invalidate_cache(self, master, cache): "sentinel_setup", [ { - "cache_config": CacheConfiguration(max_size=128), - "use_cache": True, + "cache_config": CacheConfig(max_size=128), "force_master_ip": "localhost", } ], @@ -748,7 +719,7 @@ def test_cache_clears_on_disconnect(self, master, cache): # Force disconnection master.connection_pool.get_connection("_").disconnect() # Make sure cache_data is empty - assert len(cache.get_collection()) == 0 + assert cache.get_size() == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -759,8 +730,7 @@ class TestSSLCache: "r", [ { - "cache": DefaultCache(CacheConfiguration(max_size=128)), - "use_cache": True, + "cache": DefaultCache(CacheConfig(max_size=128)), "ssl": True, } ], @@ -780,6 +750,9 @@ def test_get_from_cache(self, r, r2, cache): ) # change key in redis (cause invalidation) assert r2.set("foo", "barbar") + # Timeout needed for SSL connection because there's timeout + # between data appears in socket buffer + time.sleep(0.1) # Retrieves a new value from server and cache_data it assert r.get("foo") == b"barbar" # Make sure that new value was cached @@ -792,8 +765,7 @@ def test_get_from_cache(self, r, r2, cache): "r", [ { - "use_cache": True, - "cache_config": CacheConfiguration(max_size=128), + "cache_config": CacheConfig(max_size=128), "ssl": True, }, ], @@ -814,6 +786,9 @@ def test_get_from_custom_cache(self, r, r2): ) # change key in redis (cause invalidation) r2.set("foo", "barbar") + # Timeout needed for SSL connection because there's timeout + # between data appears in socket buffer + time.sleep(0.1) # Retrieves a new value from server and cache_data it assert r.get("foo") == b"barbar" # Make sure that new value was cached @@ -826,8 +801,7 @@ def test_get_from_custom_cache(self, r, r2): "r", [ { - "cache_config": CacheConfiguration(max_size=128), - "use_cache": True, + "cache_config": CacheConfig(max_size=128, health_check_interval=1.0), "ssl": True, } ], @@ -848,7 +822,7 @@ def test_health_check_invalidate_cache(self, r, r2): # change key in redis (cause invalidation) r2.set("foo", "barbar") # Wait for health check - time.sleep(2) + time.sleep(1.0) # Make sure that value was invalidated assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None @@ -856,8 +830,7 @@ def test_health_check_invalidate_cache(self, r, r2): "r", [ { - "cache_config": CacheConfiguration(max_size=128), - "use_cache": True, + "cache_config": CacheConfig(max_size=128), "ssl": True, } ], @@ -879,6 +852,9 @@ def test_cache_invalidate_all_related_responses(self, r): # Invalidate one of the keys and make sure # that all associated cached entries was removed assert r.set("foo", "baz") + # Timeout needed for SSL connection because there's timeout + # between data appears in socket buffer + time.sleep(0.1) assert r.get("foo") == b"baz" assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None assert ( @@ -889,18 +865,22 @@ def test_cache_invalidate_all_related_responses(self, r): class TestUnitDefaultCache: def test_get_eviction_policy(self): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) assert isinstance(cache.get_eviction_policy(), LRUPolicy) def test_get_max_size(self): - cache = DefaultCache(CacheConfiguration(max_size=5)) - assert cache.get_max_size() == 5 + cache = DefaultCache(CacheConfig(max_size=5)) + assert cache.get_config().get_max_size() == 5 + + def test_get_size(self): + cache = DefaultCache(CacheConfig(max_size=5)) + assert cache.get_size() == 0 @pytest.mark.parametrize( "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) def test_set_non_existing_cache_key(self, cache_key): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) assert cache.set( CacheEntry( @@ -913,7 +893,7 @@ def test_set_non_existing_cache_key(self, cache_key): "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) def test_set_updates_existing_cache_key(self, cache_key): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) assert cache.set( CacheEntry( @@ -935,7 +915,7 @@ def test_set_updates_existing_cache_key(self, cache_key): "cache_key", [{"command": "HRANDFIELD", "redis_keys": ("bar",)}], indirect=True ) def test_set_does_not_store_not_allowed_key(self, cache_key): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) assert not cache.set( CacheEntry( @@ -944,7 +924,7 @@ def test_set_does_not_store_not_allowed_key(self, cache_key): ) def test_set_evict_lru_cache_key_on_reaching_max_size(self): - cache = DefaultCache(CacheConfiguration(max_size=3)) + cache = DefaultCache(CacheConfig(max_size=3)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) cache_key3 = CacheKey(command="GET", redis_keys=("foo2",)) @@ -987,7 +967,7 @@ def test_set_evict_lru_cache_key_on_reaching_max_size(self): "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) def test_get_return_correct_value(self, cache_key): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) assert cache.set( CacheEntry( @@ -1000,12 +980,17 @@ def test_get_return_correct_value(self, cache_key): assert cache.get(wrong_key) is None result = cache.get(cache_key) + assert cache.set( + CacheEntry( + cache_key=cache_key, cache_value=b"new_val", status=CacheEntryStatus.VALID + ) + ) # Make sure that result is immutable. - assert result != cache.get(cache_key) + assert result.cache_value != cache.get(cache_key).cache_value def test_delete_by_cache_keys_removes_associated_entries(self): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) @@ -1038,7 +1023,7 @@ def test_delete_by_cache_keys_removes_associated_entries(self): assert cache.get(cache_key3).cache_value == b"bar2" def test_delete_by_redis_keys_removes_associated_entries(self): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) @@ -1072,7 +1057,7 @@ def test_delete_by_redis_keys_removes_associated_entries(self): assert cache.get(cache_key4).cache_value == b"bar3" def test_flush(self): - cache = DefaultCache(CacheConfiguration(max_size=5)) + cache = DefaultCache(CacheConfig(max_size=5)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) @@ -1106,7 +1091,7 @@ def test_type(self): def test_evict_next(self): cache = DefaultCache( - CacheConfiguration(max_size=5, eviction_policy=EvictionPolicy.LRU) + CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) policy = cache.get_eviction_policy() @@ -1129,7 +1114,7 @@ def test_evict_next(self): def test_evict_many(self): cache = DefaultCache( - CacheConfiguration(max_size=5, eviction_policy=EvictionPolicy.LRU) + CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) policy = cache.get_eviction_policy() cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) @@ -1161,7 +1146,7 @@ def test_evict_many(self): def test_touch(self): cache = DefaultCache( - CacheConfiguration(max_size=5, eviction_policy=EvictionPolicy.LRU) + CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) policy = cache.get_eviction_policy() @@ -1209,16 +1194,16 @@ class TestUnitCacheConfiguration: MAX_SIZE = 100 EVICTION_POLICY = EvictionPolicy.LRU - def test_get_max_size(self, cache_conf: CacheConfiguration): + def test_get_max_size(self, cache_conf: CacheConfig): assert self.MAX_SIZE == cache_conf.get_max_size() - def test_get_eviction_policy(self, cache_conf: CacheConfiguration): + def test_get_eviction_policy(self, cache_conf: CacheConfig): assert self.EVICTION_POLICY == cache_conf.get_eviction_policy() - def test_is_exceeds_max_size(self, cache_conf: CacheConfiguration): + def test_is_exceeds_max_size(self, cache_conf: CacheConfig): assert not cache_conf.is_exceeds_max_size(self.MAX_SIZE) assert cache_conf.is_exceeds_max_size(self.MAX_SIZE + 1) - def test_is_allowed_to_cache(self, cache_conf: CacheConfiguration): + def test_is_allowed_to_cache(self, cache_conf: CacheConfig): assert cache_conf.is_allowed_to_cache("GET") assert not cache_conf.is_allowed_to_cache("SET") diff --git a/tests/test_connection.py b/tests/test_connection.py index 0b1f6fb5ad..7a039c1442 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,7 +1,8 @@ import socket import types +from typing import Any from unittest import mock -from unittest.mock import patch +from unittest.mock import patch, call, Mock import pytest import redis @@ -9,7 +10,7 @@ from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff from redis.cache import ( - CacheConfiguration, + CacheConfig, CacheEntry, CacheEntryStatus, CacheInterface, @@ -377,14 +378,14 @@ def test_throws_error_on_cache_enable_in_resp2(self): with pytest.raises( RedisError, match="Client caching is only supported with RESP version 3" ): - ConnectionPool(protocol=2, use_cache=True) + ConnectionPool(protocol=2, cache_config=CacheConfig()) def test_throws_error_on_incorrect_cache_implementation(self): with pytest.raises(ValueError, match="Cache must implement CacheInterface"): - ConnectionPool(protocol=3, use_cache=True, cache='wrong') + ConnectionPool(protocol=3, cache='wrong') def test_returns_custom_cache_implementation(self, mock_cache): - connection_pool = ConnectionPool(protocol=3, use_cache=True, cache=mock_cache) + connection_pool = ConnectionPool(protocol=3, cache=mock_cache) assert mock_cache == connection_pool.cache connection_pool.disconnect() @@ -396,8 +397,7 @@ def test_creates_cache_with_custom_cache_factory( connection_pool = ConnectionPool( protocol=3, - use_cache=True, - cache_config=CacheConfiguration(max_size=5), + cache_config=CacheConfig(max_size=5), cache_factory=mock_cache_factory, ) @@ -406,16 +406,16 @@ def test_creates_cache_with_custom_cache_factory( def test_creates_cache_with_given_configuration(self, mock_cache): connection_pool = ConnectionPool( - protocol=3, use_cache=True, cache_config=CacheConfiguration(max_size=100) + protocol=3, cache_config=CacheConfig(max_size=100) ) assert isinstance(connection_pool.cache, CacheInterface) - assert connection_pool.cache.get_max_size() == 100 + assert connection_pool.cache.get_config().get_max_size() == 100 assert isinstance(connection_pool.cache.get_eviction_policy(), LRUPolicy) connection_pool.disconnect() def test_make_connection_proxy_connection_on_given_cache(self): - connection_pool = ConnectionPool(protocol=3, use_cache=True) + connection_pool = ConnectionPool(protocol=3, cache_config=CacheConfig()) assert isinstance(connection_pool.make_connection(), CacheProxyConnection) connection_pool.disconnect() @@ -423,7 +423,7 @@ def test_make_connection_proxy_connection_on_given_cache(self): class TestUnitCacheProxyConnection: def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): - cache = DefaultCache(CacheConfiguration(max_size=10)) + cache = DefaultCache(CacheConfig(max_size=10)) cache_key = CacheKey(command="GET", redis_keys=("foo",)) cache.set( @@ -442,3 +442,67 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): proxy_connection.disconnect() assert len(cache.get_collection()) == 0 + + def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): + mock_connection.retry = 'mock' + mock_connection.host = "mock" + mock_connection.port = "mock" + + mock_cache.is_cachable.return_value = True + mock_cache.get.side_effect = [ + None, + None, + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS + ), + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID + ), + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID + ), + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID + ), + ] + mock_connection.send_command.return_value = Any + mock_connection.read_response.return_value = b'bar' + mock_connection.can_read.return_value = False + + proxy_connection = CacheProxyConnection(mock_connection, mock_cache) + proxy_connection.send_command(*['GET', 'foo'], **{'keys': ['foo']}) + assert proxy_connection.read_response() == b'bar' + assert proxy_connection.read_response() == b'bar' + + mock_connection.read_response.assert_called_once() + mock_cache.set.assert_has_calls([ + call(CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS + )), + call(CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID + )), + ]) + + mock_cache.get.assert_has_calls([ + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + ]) + + From 4459dd691fe5386c2ce8ac6cc01201fa0ac9c4bf Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 30 Aug 2024 15:08:16 +0300 Subject: [PATCH 52/78] Revert test skip --- tests/test_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index d16009e270..3d7733478f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -39,7 +39,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -#@skip_if_resp_version(2) +@skip_if_resp_version(2) class TestCache: @pytest.mark.parametrize( "r", From 9fb5aa243b188b10f732ed1ef357167d02180daa Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 2 Sep 2024 11:56:49 +0300 Subject: [PATCH 53/78] Added documentation and codestyle fixes --- docs/resp3_features.rst | 32 +++++++++++++++++ redis/cache.py | 5 ++- redis/client.py | 5 ++- redis/connection.py | 1 - redis/retry.py | 2 -- tests/test_cache.py | 4 ++- tests/test_connection.py | 74 +++++++++++++++++++++------------------- 7 files changed, 80 insertions(+), 43 deletions(-) diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst index 11c01985a0..3f35250233 100644 --- a/docs/resp3_features.rst +++ b/docs/resp3_features.rst @@ -67,3 +67,35 @@ This means that should you want to perform something, on a given push notificati >> p = r.pubsub(push_handler_func=our_func) In the example above, upon receipt of a push notification, rather than log the message, in the case where specific text occurs, an IOError is raised. This example, highlights how one could start implementing a customized message handler. + +Client-side caching +------------------- + +Client-side caching is a technique used to create high performance services. +It exploits the memory available on application servers, servers that are usually distinct computers compared to the database nodes, to store some subset of the database information directly in the application side. +For more information please check `official Redis documentation `_. +Please notice that this feature only available with RESP3 protocol enabled in sync client only. Supported in standalone, Cluster and Sentinel clients. + +Basic usage: + +Enable caching with default configuration: + +.. code:: python + + >>> import redis + >>> from redis.cache import CacheConfig + >>> r = redis.Redis(host='localhost', port=6379, protocol=3, cache_config=CacheConfig()) + +The same interface applies to Redis Cluster and Sentinel. + +Enable caching with custom cache implementation: + +.. code:: python + + >>> import redis + >>> from foo.bar import CacheImpl + >>> r = redis.Redis(host='localhost', port=6379, protocol=3, cache=CacheImpl()) + +CacheImpl should implement a `CacheInterface` specified in `redis.cache` package. + +More robust documentation soon will be available at `official Redis documentation `_. diff --git a/redis/cache.py b/redis/cache.py index 6320682e97..53faf8d055 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,9 +1,8 @@ import copy -import time from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum -from typing import Any, Collection, Hashable, List, Optional, Union +from typing import Any, List, Optional, Union class CacheEntryStatus(Enum): @@ -260,7 +259,7 @@ def touch(self, cache_key: CacheKey) -> None: self._assert_cache() if self._cache.get_collection().get(cache_key) is None: - raise ValueError(f"Given entry does not belong to the cache") + raise ValueError("Given entry does not belong to the cache") self._cache.get_collection().move_to_end(cache_key) diff --git a/redis/client.py b/redis/client.py index ec2edfcb35..bf3432e7eb 100755 --- a/redis/client.py +++ b/redis/client.py @@ -319,7 +319,10 @@ def __init__( self.connection_pool = connection_pool - if (cache_config or cache) and self.connection_pool.get_protocol() not in [3, "3"]: + if (cache_config or cache) and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: raise RedisError("Client caching is only supported with RESP version 3") self.connection = None diff --git a/redis/connection.py b/redis/connection.py index 1d13a25c46..97763797f9 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -13,7 +13,6 @@ from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( - CacheConfig, CacheEntry, CacheEntryStatus, CacheFactory, diff --git a/redis/retry.py b/redis/retry.py index 1b0fe113f8..648892ca14 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,12 +1,10 @@ import socket -import time from time import sleep from typing import ( TYPE_CHECKING, Any, Callable, Iterable, - Optional, Tuple, Type, TypeVar, diff --git a/tests/test_cache.py b/tests/test_cache.py index 3d7733478f..221967dc00 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -982,7 +982,9 @@ def test_get_return_correct_value(self, cache_key): result = cache.get(cache_key) assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"new_val", status=CacheEntryStatus.VALID + cache_key=cache_key, + cache_value=b"new_val", + status=CacheEntryStatus.VALID, ) ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 7a039c1442..215f74850b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,7 +2,7 @@ import types from typing import Any from unittest import mock -from unittest.mock import patch, call, Mock +from unittest.mock import call, patch import pytest import redis @@ -16,8 +16,6 @@ CacheInterface, CacheKey, DefaultCache, - EvictionPolicy, - EvictionPolicyInterface, LRUPolicy, ) from redis.connection import ( @@ -382,7 +380,7 @@ def test_throws_error_on_cache_enable_in_resp2(self): def test_throws_error_on_incorrect_cache_implementation(self): with pytest.raises(ValueError, match="Cache must implement CacheInterface"): - ConnectionPool(protocol=3, cache='wrong') + ConnectionPool(protocol=3, cache="wrong") def test_returns_custom_cache_implementation(self, mock_cache): connection_pool = ConnectionPool(protocol=3, cache=mock_cache) @@ -444,7 +442,7 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): assert len(cache.get_collection()) == 0 def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): - mock_connection.retry = 'mock' + mock_connection.retry = "mock" mock_connection.host = "mock" mock_connection.port = "mock" @@ -455,54 +453,60 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, - status=CacheEntryStatus.IN_PROGRESS + status=CacheEntryStatus.IN_PROGRESS, ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", - status=CacheEntryStatus.VALID + status=CacheEntryStatus.VALID, ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", - status=CacheEntryStatus.VALID + status=CacheEntryStatus.VALID, ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", - status=CacheEntryStatus.VALID + status=CacheEntryStatus.VALID, ), ] mock_connection.send_command.return_value = Any - mock_connection.read_response.return_value = b'bar' + mock_connection.read_response.return_value = b"bar" mock_connection.can_read.return_value = False proxy_connection = CacheProxyConnection(mock_connection, mock_cache) - proxy_connection.send_command(*['GET', 'foo'], **{'keys': ['foo']}) - assert proxy_connection.read_response() == b'bar' - assert proxy_connection.read_response() == b'bar' + proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) + assert proxy_connection.read_response() == b"bar" + assert proxy_connection.read_response() == b"bar" mock_connection.read_response.assert_called_once() - mock_cache.set.assert_has_calls([ - call(CacheEntry( - cache_key=CacheKey(command="GET", redis_keys=("foo",)), - cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, - status=CacheEntryStatus.IN_PROGRESS - )), - call(CacheEntry( - cache_key=CacheKey(command="GET", redis_keys=("foo",)), - cache_value=b"bar", - status=CacheEntryStatus.VALID - )), - ]) - - mock_cache.get.assert_has_calls([ - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - call(CacheKey(command="GET", redis_keys=("foo",))), - ]) - + mock_cache.set.assert_has_calls( + [ + call( + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, + status=CacheEntryStatus.IN_PROGRESS, + ) + ), + call( + CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + ) + ), + ] + ) + mock_cache.get.assert_has_calls( + [ + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + call(CacheKey(command="GET", redis_keys=("foo",))), + ] + ) From 01e405e977461c3f2e5fb113c197f169b6dc5bf9 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 2 Sep 2024 12:36:42 +0300 Subject: [PATCH 54/78] Updated excluded wordlist --- .github/wordlist.txt | 1 + redis/retry.py | 10 +--------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/wordlist.txt b/.github/wordlist.txt index ca2102b825..3ea543748e 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -1,6 +1,7 @@ APM ARGV BFCommands +CacheImpl CFCommands CMSCommands ClusterNode diff --git a/redis/retry.py b/redis/retry.py index 648892ca14..03fd973c4c 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,14 +1,6 @@ import socket from time import sleep -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Tuple, - Type, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, TimeoutError From b5a81332577f944eb0e29de1f701d7a1a031a9a3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 3 Sep 2024 13:42:25 +0300 Subject: [PATCH 55/78] Added health_check thread cancelling in BlockingPool --- redis/connection.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/redis/connection.py b/redis/connection.py index 97763797f9..eb21caec18 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1669,3 +1669,11 @@ def disconnect(self): self._checkpid() for connection in self._connections: connection.disconnect() + + # Send an event to stop scheduled healthcheck execution. + if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set(): + self._hc_cancel_event.set() + + # Joins healthcheck thread on disconnect. + if self._hc_thread is not None and not self._hc_thread.is_alive(): + self._hc_thread.join() From 4035ce6e1ace28979b27cf0e9d747b8946433b72 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 5 Sep 2024 10:00:41 +0300 Subject: [PATCH 56/78] Revert argument rename, extended documentation --- docs/resp3_features.rst | 6 ++++++ redis/cluster.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst index 3f35250233..6ad2cb4f48 100644 --- a/docs/resp3_features.rst +++ b/docs/resp3_features.rst @@ -98,4 +98,10 @@ Enable caching with custom cache implementation: CacheImpl should implement a `CacheInterface` specified in `redis.cache` package. +Explicit disconnect + +It's important to call `disconnect()` or `disconnect_connection_pools()` in case of Cluster to properly close the connection to server. +For caching purposes, we're using a separate thread that performs health checks with configurable interval and it relies on +`disconnect()` or `disconnect_connection_pools()` to be called before the shutdown. + More robust documentation soon will be available at `official Redis documentation `_. diff --git a/redis/cluster.py b/redis/cluster.py index a3142028fc..8960cafc7b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1056,8 +1056,8 @@ def _parse_target_nodes(self, target_nodes): ) return nodes - def execute_command(self, *args, **options): - return self._internal_execute_command(*args, **options) + def execute_command(self, *args, **kwargs): + return self._internal_execute_command(*args, **kwargs) def _internal_execute_command(self, *args, **kwargs): """ From 6c47c64425ab55d6e9e101cfbd1a4dee18d7e89f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 9 Sep 2024 15:18:21 +0300 Subject: [PATCH 57/78] Updated NodesManager to create shared cache between all nodes --- redis/cluster.py | 20 +++++++++++++------- redis/connection.py | 25 +++++++++++-------------- tests/test_cache.py | 2 ++ 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 8960cafc7b..f0a7d46e43 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,7 +9,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.cache import CacheConfig, CacheInterface +from redis.cache import CacheConfig, CacheInterface, CacheFactoryInterface, CacheFactory from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args @@ -1327,6 +1327,7 @@ def __init__( address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, + cache_factory: Optional[CacheFactoryInterface] = None, **kwargs, ): self.nodes_cache = {} @@ -1339,8 +1340,9 @@ def __init__( self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap - self.cache = cache - self.cache_config = cache_config + self._cache = cache + self._cache_config = cache_config + self._cache_factory = cache_factory self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() @@ -1484,15 +1486,13 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - kwargs.update({"cache": self.cache}) - kwargs.update({"cache_config": self.cache_config}) + kwargs.update({"cache": self._cache}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: r = Redis( host=host, port=port, - cache=self.cache, - cache_config=self.cache_config, + cache=self._cache, **kwargs, ) return r @@ -1624,6 +1624,12 @@ def initialize(self): f"one reachable node: {str(exception)}" ) from exception + if self._cache is None and self._cache_config is not None: + if self._cache_factory is None: + self._cache = CacheFactory(self._cache_config).get_cache() + else: + self._cache = self._cache_factory.get_cache() + # Create Redis connections to all nodes self.create_redis_connections(list(tmp_nodes_cache.values())) diff --git a/redis/connection.py b/redis/connection.py index eb21caec18..8bda1c9f4b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1468,13 +1468,7 @@ def disconnect(self, inuse_connections: bool = True) -> None: for connection in connections: connection.disconnect() - # Send an event to stop scheduled healthcheck execution. - if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set(): - self._hc_cancel_event.set() - - # Joins healthcheck thread on disconnect. - if self._hc_thread is not None and not self._hc_thread.is_alive(): - self._hc_thread.join() + self.stop_scheduled_healthcheck() def close(self) -> None: """Close the pool, disconnecting all connections""" @@ -1496,6 +1490,15 @@ def run_scheduled_healthcheck(self) -> None: self._perform_health_check, hc_interval, self._hc_cancel_event ) + def stop_scheduled_healthcheck(self) -> None: + # Send an event to stop scheduled healthcheck execution. + if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set(): + self._hc_cancel_event.set() + + # Joins healthcheck thread on disconnect. + if self._hc_thread is not None and not self._hc_thread.is_alive(): + self._hc_thread.join() + def _perform_health_check(self, done: threading.Event) -> None: self._checkpid() with self._lock: @@ -1670,10 +1673,4 @@ def disconnect(self): for connection in self._connections: connection.disconnect() - # Send an event to stop scheduled healthcheck execution. - if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set(): - self._hc_cancel_event.set() - - # Joins healthcheck thread on disconnect. - if self._hc_thread is not None and not self._hc_thread.is_alive(): - self._hc_thread.join() + self.stop_scheduled_healthcheck() diff --git a/tests/test_cache.py b/tests/test_cache.py index 221967dc00..6cb27ce92e 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -374,6 +374,8 @@ def test_get_from_cache(self, r): cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value == b"barbar" ) + # Make sure that cache is shared between nodes. + assert cache == r.nodes_manager.get_node_from_slot(1).redis_connection.get_cache() @pytest.mark.parametrize( "r", From c47d4aae9b8abf211497d0dc0be1a9b9d607c0be Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 9 Sep 2024 15:20:17 +0300 Subject: [PATCH 58/78] Codestyle fixes --- redis/cluster.py | 2 +- tests/test_cache.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index f0a7d46e43..bd8ddb9a84 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -9,7 +9,7 @@ from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.cache import CacheConfig, CacheInterface, CacheFactoryInterface, CacheFactory +from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args diff --git a/tests/test_cache.py b/tests/test_cache.py index 6cb27ce92e..e46cc31013 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -375,7 +375,9 @@ def test_get_from_cache(self, r): == b"barbar" ) # Make sure that cache is shared between nodes. - assert cache == r.nodes_manager.get_node_from_slot(1).redis_connection.get_cache() + assert ( + cache == r.nodes_manager.get_node_from_slot(1).redis_connection.get_cache() + ) @pytest.mark.parametrize( "r", From fd361a7774ed4833a20ef72a3bd4b116fb768093 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 9 Sep 2024 18:13:18 +0300 Subject: [PATCH 59/78] Updated docs --- docs/resp3_features.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst index 6ad2cb4f48..7ca8a83b6d 100644 --- a/docs/resp3_features.rst +++ b/docs/resp3_features.rst @@ -100,8 +100,8 @@ CacheImpl should implement a `CacheInterface` specified in `redis.cache` package Explicit disconnect -It's important to call `disconnect()` or `disconnect_connection_pools()` in case of Cluster to properly close the connection to server. +It's important to call `close()` to properly close the connection to server. For caching purposes, we're using a separate thread that performs health checks with configurable interval and it relies on -`disconnect()` or `disconnect_connection_pools()` to be called before the shutdown. +`close()` to be called before the shutdown. More robust documentation soon will be available at `official Redis documentation `_. From 97ebebf32c85d7ebfb2f620b36fd4eec8e5e1318 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 10 Sep 2024 14:43:00 +0300 Subject: [PATCH 60/78] Added version restrictions --- redis/connection.py | 38 ++++++++++++++++++++++++++++++++------ redis/utils.py | 30 ++++++++++++++++++++++++++++++ tests/test_cache.py | 6 +++++- tests/test_utils.py | 27 +++++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 7 deletions(-) create mode 100644 tests/test_utils.py diff --git a/redis/connection.py b/redis/connection.py index 8bda1c9f4b..2e4bc1d747 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Type, Union +from typing import Any, Callable, List, Optional, Type, Union, Dict from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -43,7 +43,7 @@ SSL_AVAILABLE, format_error_message, get_lib_version, - str_if_bytes, + str_if_bytes, compare_versions, ) if HIREDIS_AVAILABLE: @@ -197,6 +197,11 @@ def pack_command(self, *args): def pack_commands(self, commands): pass + @property + @abstractmethod + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + pass + class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" @@ -272,6 +277,7 @@ def __init__( self.next_health_check = 0 self.redis_connect_func = redis_connect_func self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size self.set_parser(parser_class) @@ -414,7 +420,7 @@ def on_connect(self): if len(auth_args) == 1: auth_args = ["default", auth_args[0]] self.send_command("HELLO", self.protocol, "AUTH", *auth_args) - response = self.read_response() + self.handshake_metadata = self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" # ) != self.protocol: @@ -445,10 +451,10 @@ def on_connect(self): self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) self.send_command("HELLO", self.protocol) - response = self.read_response() + self.handshake_metadata = self.read_response() if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol + self.handshake_metadata.get(b"proto") != self.protocol + and self.handshake_metadata.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") @@ -649,6 +655,14 @@ def pack_commands(self, commands): def get_protocol(self) -> int or str: return self.protocol + @property + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + return self._handshake_metadata + + @handshake_metadata.setter + def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): + self._handshake_metadata = value + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -731,6 +745,7 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" + MIN_ALLOWED_VERSION = '7.4.0' def __init__(self, conn: ConnectionInterface, cache: CacheInterface): self.pid = os.getpid() @@ -759,6 +774,17 @@ def set_parser(self, parser_class): def connect(self): self._conn.connect() + server_ver = self._conn.handshake_metadata.get(b"version", None) + if server_ver is None: + raise ConnectionError("Cannot retrieve information about server version") + + server_ver = server_ver.decode("utf-8") + + if compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1: + raise ConnectionError( + "Server version does not satisfies a minimal requirement for client-side caching" + ) + def on_connect(self): self._conn.on_connect() diff --git a/redis/utils.py b/redis/utils.py index a0f31f7ca4..4b3a4647dc 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -153,3 +153,33 @@ def format_error_message(host_error: str, exception: BaseException) -> str: f"Error {exception.args[0]} connecting to {host_error}. " f"{exception.args[1]}." ) + + +def compare_versions(version1: str, version2: str) -> int: + """ + Compare two versions. + + :return: -1 if version1 > version2 + 0 if both versions are equal + 1 if version1 < version2 + """ + + num_versions1 = list(map(int, version1.split("."))) + num_versions2 = list(map(int, version2.split("."))) + + if len(num_versions1) > len(num_versions2): + diff = len(num_versions1) - len(num_versions2) + for _ in range(diff): + num_versions2.append(0) + elif len(num_versions1) < len(num_versions2): + diff = len(num_versions2) - len(num_versions1) + for _ in range(diff): + num_versions1.append(0) + + for i, ver in enumerate(num_versions1): + if num_versions1[i] > num_versions2[i]: + return -1 + elif num_versions1[i] < num_versions2[i]: + return 1 + + return 0 diff --git a/tests/test_cache.py b/tests/test_cache.py index e46cc31013..1a26fa668d 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -13,7 +13,7 @@ LRUPolicy, ) from redis.utils import HIREDIS_AVAILABLE -from tests.conftest import _get_client, skip_if_resp_version +from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt @pytest.fixture() @@ -40,6 +40,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") class TestCache: @pytest.mark.parametrize( "r", @@ -343,6 +344,7 @@ def test_cache_flushed_on_server_flush(self, r): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlycluster @skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") class TestClusterCache: @pytest.mark.parametrize( "r", @@ -605,6 +607,7 @@ def test_cache_flushed_on_server_flush(self, r, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") class TestSentinelCache: @pytest.mark.parametrize( "sentinel_setup", @@ -729,6 +732,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster @skip_if_resp_version(2) +@skip_if_server_version_lt("7.4.0") class TestSSLCache: @pytest.mark.parametrize( "r", diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..764ef5d0a9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,27 @@ +import pytest +from redis.utils import compare_versions + + +@pytest.mark.parametrize( + "version1,version2,expected_res", + [ + ("1.0.0", "0.9.0", -1), + ("1.0.0", "1.0.0", 0), + ("0.9.0", "1.0.0", 1), + ("1.09.0", "1.9.0", 0), + ("1.090.0", "1.9.0", -1), + ("1", "0.9.0", -1), + ("1", "1.0.0", 0), + ], + ids=[ + "version1 > version2", + "version1 == version2", + "version1 < version2", + "version1 == version2 - different minor format", + "version1 > version2 - different minor format", + "version1 > version2 - major version only", + "version1 == version2 - major version only", + ], +) +def test_compare_versions(version1, version2, expected_res): + assert compare_versions(version1, version2) == expected_res From 6387a864e0cd468c39160f570ad3b8fe1286ae62 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 10 Sep 2024 14:45:53 +0300 Subject: [PATCH 61/78] Added missing property getter --- redis/connection.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 2e4bc1d747..8c3ad5a5eb 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -9,7 +9,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Any, Callable, List, Optional, Type, Union, Dict +from typing import Any, Callable, Dict, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -41,9 +41,10 @@ HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, + compare_versions, format_error_message, get_lib_version, - str_if_bytes, compare_versions, + str_if_bytes, ) if HIREDIS_AVAILABLE: @@ -745,7 +746,7 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" - MIN_ALLOWED_VERSION = '7.4.0' + MIN_ALLOWED_VERSION = "7.4.0" def __init__(self, conn: ConnectionInterface, cache: CacheInterface): self.pid = os.getpid() @@ -892,6 +893,10 @@ def pack_command(self, *args): def pack_commands(self, commands): return self._conn.pack_commands(commands) + @property + def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: + return self._conn.handshake_metadata + def _connect(self): self._conn._connect() From 097d92eb9949d915a74e3a8c5d4b139c76b33bc4 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 10 Sep 2024 14:57:50 +0300 Subject: [PATCH 62/78] Updated Redis server version --- .github/workflows/integration.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index c4da3bf3aa..b10edf2fb4 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -27,7 +27,7 @@ env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} # this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665 COVERAGE_CORE: sysmon - REDIS_IMAGE: redis:7.4-rc2 + REDIS_IMAGE: redis:latest REDIS_STACK_IMAGE: redis/redis-stack-server:latest jobs: From c2a4edfd16edb76f30b8c15cdb8e28e570d69806 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 10 Sep 2024 15:03:50 +0300 Subject: [PATCH 63/78] Skip on long exception message --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index 8c3ad5a5eb..f104a3bb8e 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -783,7 +783,7 @@ def connect(self): if compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1: raise ConnectionError( - "Server version does not satisfies a minimal requirement for client-side caching" + "Server version does not satisfies a minimal requirement for client-side caching" # noqa: E501 ) def on_connect(self): From c11020763626a49422fe668d1d13124748968034 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 12 Sep 2024 11:11:27 +0300 Subject: [PATCH 64/78] Removed keys entry as it's csc specific --- redis/cluster.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/redis/cluster.py b/redis/cluster.py index bd8ddb9a84..fbf5428d40 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2265,6 +2265,8 @@ def _send_cluster_commands( response = [] for c in sorted(stack, key=lambda x: x.position): if c.args[0] in self.cluster_response_callbacks: + # Remove keys entry, it needs only for cache. + c.options.pop("keys", None) c.result = self.cluster_response_callbacks[c.args[0]]( c.result, **c.options ) From ffff100cf649ed0999f85982bb537f7125c23b57 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 12 Sep 2024 13:51:59 +0300 Subject: [PATCH 65/78] Updated exception message for CSC --- redis/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/connection.py b/redis/connection.py index f104a3bb8e..867f72d5a5 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -783,7 +783,7 @@ def connect(self): if compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1: raise ConnectionError( - "Server version does not satisfies a minimal requirement for client-side caching" # noqa: E501 + "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 ) def on_connect(self): From 9852b781178ed88d944cf5b335828a7cfc30616f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 12 Sep 2024 17:08:19 +0300 Subject: [PATCH 66/78] Updated condition by adding server name check --- redis/connection.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 867f72d5a5..f1ca69fde0 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -747,6 +747,7 @@ def ensure_string(key): class CacheProxyConnection(ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" MIN_ALLOWED_VERSION = "7.4.0" + DEFAULT_SERVER_NAME = b"redis" def __init__(self, conn: ConnectionInterface, cache: CacheInterface): self.pid = os.getpid() @@ -775,13 +776,17 @@ def set_parser(self, parser_class): def connect(self): self._conn.connect() - server_ver = self._conn.handshake_metadata.get(b"version", None) + server_name = self._conn.handshake_metadata.get(b"server") + server_ver = self._conn.handshake_metadata.get(b"version") if server_ver is None: raise ConnectionError("Cannot retrieve information about server version") server_ver = server_ver.decode("utf-8") - if compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1: + if ( + server_name != self.DEFAULT_SERVER_NAME + or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 + ): raise ConnectionError( "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 ) From 97abacde5bbaef1bdf1e72b4527f5d1057188362 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 13 Sep 2024 15:46:12 +0300 Subject: [PATCH 67/78] Added test coverage for decoded responses --- redis/cache.py | 3 +- redis/connection.py | 24 ++++------ redis/utils.py | 9 ++++ tests/conftest.py | 2 + tests/test_cache.py | 111 ++++++++++++++++++++++++++++++-------------- 5 files changed, 98 insertions(+), 51 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index 53faf8d055..61626beef7 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -201,7 +201,8 @@ def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: keys_to_delete = [] for redis_key in redis_keys: - redis_key = redis_key.decode() + if isinstance(redis_key, bytes): + redis_key = redis_key.decode() for cache_key in self._cache: if redis_key in cache_key.get_redis_keys(): keys_to_delete.append(cache_key) diff --git a/redis/connection.py b/redis/connection.py index f1ca69fde0..a2d9f4f2c2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -44,7 +44,7 @@ compare_versions, format_error_message, get_lib_version, - str_if_bytes, + str_if_bytes, ensure_string, ) if HIREDIS_AVAILABLE: @@ -735,19 +735,10 @@ def _host_error(self): return f"{self.host}:{self.port}" -def ensure_string(key): - if isinstance(key, bytes): - return key.decode("utf-8") - elif isinstance(key, str): - return key - else: - raise TypeError("Key must be either a string or bytes") - - class CacheProxyConnection(ConnectionInterface): DUMMY_CACHE_VALUE = b"foo" MIN_ALLOWED_VERSION = "7.4.0" - DEFAULT_SERVER_NAME = b"redis" + DEFAULT_SERVER_NAME = "redis" def __init__(self, conn: ConnectionInterface, cache: CacheInterface): self.pid = os.getpid() @@ -776,12 +767,17 @@ def set_parser(self, parser_class): def connect(self): self._conn.connect() - server_name = self._conn.handshake_metadata.get(b"server") - server_ver = self._conn.handshake_metadata.get(b"version") + server_name = self._conn.handshake_metadata.get(b"server", None) + if server_name is None: + server_name = self._conn.handshake_metadata.get("server", None) + server_ver = self._conn.handshake_metadata.get(b"version", None) if server_ver is None: + server_ver = self._conn.handshake_metadata.get("version", None) + if server_ver is None or server_ver is None: raise ConnectionError("Cannot retrieve information about server version") - server_ver = server_ver.decode("utf-8") + server_ver = ensure_string(server_ver) + server_name = ensure_string(server_name) if ( server_name != self.DEFAULT_SERVER_NAME diff --git a/redis/utils.py b/redis/utils.py index 4b3a4647dc..b4e9afb054 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -183,3 +183,12 @@ def compare_versions(version1: str, version2: str) -> int: return 1 return 0 + + +def ensure_string(key): + if isinstance(key, bytes): + return key.decode("utf-8") + elif isinstance(key, str): + return key + else: + raise TypeError("Key must be either a string or bytes") diff --git a/tests/conftest.py b/tests/conftest.py index 0755fd390e..0c98eee4d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -442,6 +442,7 @@ def sentinel_setup(request): cache = request.param.get("cache", None) cache_config = request.param.get("cache_config", None) force_master_ip = request.param.get("force_master_ip", None) + decode_responses = request.param.get("decode_responses", False) sentinel = Sentinel( sentinel_endpoints, force_master_ip=force_master_ip, @@ -449,6 +450,7 @@ def sentinel_setup(request): cache=cache, cache_config=cache_config, protocol=3, + decode_responses=decode_responses, **kwargs, ) yield sentinel diff --git a/tests/test_cache.py b/tests/test_cache.py index 1a26fa668d..e106cdb156 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -24,6 +24,7 @@ def r(request): protocol = request.param.get("protocol", 3) ssl = request.param.get("ssl", False) single_connection_client = request.param.get("single_connection_client", False) + decode_responses = request.param.get("decode_responses", False) with _get_client( redis.Redis, request, @@ -32,6 +33,7 @@ def r(request): single_connection_client=single_connection_client, cache=cache, cache_config=cache_config, + decode_responses=decode_responses, **kwargs, ) as client: yield client @@ -53,8 +55,13 @@ class TestCache: "cache": DefaultCache(CacheConfig(max_size=5)), "single_connection_client": False, }, + { + "cache": DefaultCache(CacheConfig(max_size=5)), + "single_connection_client": False, + "decode_responses": True, + }, ], - ids=["single", "pool"], + ids=["single", "pool", "decoded"], indirect=True, ) @pytest.mark.onlynoncluster @@ -63,20 +70,20 @@ def test_get_from_given_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -90,8 +97,13 @@ def test_get_from_given_cache(self, r, r2): "cache_config": CacheConfig(max_size=128), "single_connection_client": False, }, + { + "cache_config": CacheConfig(max_size=128), + "single_connection_client": False, + "decode_responses": True, + }, ], - ids=["single", "pool"], + ids=["single", "pool", "decoded"], indirect=True, ) @pytest.mark.onlynoncluster @@ -103,20 +115,20 @@ def test_get_from_default_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -351,7 +363,11 @@ class TestClusterCache: [ { "cache": DefaultCache(CacheConfig(max_size=128)), - } + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "decode_responses": True, + }, ], indirect=True, ) @@ -361,20 +377,20 @@ def test_get_from_cache(self, r): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) # Make sure that cache is shared between nodes. assert ( @@ -387,6 +403,10 @@ def test_get_from_cache(self, r): { "cache_config": CacheConfig(max_size=128), }, + { + "cache_config": CacheConfig(max_size=128), + "decode_responses": True, + }, ], indirect=True, ) @@ -398,20 +418,20 @@ def test_get_from_custom_cache(self, r, r2): # add key to redis assert r.set("foo", "bar") # get key from redis and save in local cache - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -615,6 +635,11 @@ class TestSentinelCache: { "cache": DefaultCache(CacheConfig(max_size=128)), "force_master_ip": "localhost", + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "force_master_ip": "localhost", + "decode_responses": True, } ], indirect=True, @@ -624,20 +649,20 @@ def test_get_from_cache(self, master): cache = master.get_cache() master.set("foo", "bar") # get key from redis and save in local cache_data - assert master.get("foo") == b"bar" + assert master.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) master.set("foo", "barbar") # get key from redis - assert master.get("foo") == b"barbar" + assert master.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -646,6 +671,10 @@ def test_get_from_cache(self, master): { "cache_config": CacheConfig(max_size=128), }, + { + "cache_config": CacheConfig(max_size=128), + "decode_responses": True, + }, ], indirect=True, ) @@ -656,20 +685,20 @@ def test_get_from_default_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache_data it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -731,7 +760,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +#@skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") class TestSSLCache: @pytest.mark.parametrize( @@ -740,6 +769,11 @@ class TestSSLCache: { "cache": DefaultCache(CacheConfig(max_size=128)), "ssl": True, + }, + { + "cache": DefaultCache(CacheConfig(max_size=128)), + "ssl": True, + "decode_responses": True, } ], indirect=True, @@ -750,11 +784,11 @@ def test_get_from_cache(self, r, r2, cache): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) assert r2.set("foo", "barbar") @@ -762,11 +796,11 @@ def test_get_from_cache(self, r, r2, cache): # between data appears in socket buffer time.sleep(0.1) # Retrieves a new value from server and cache_data it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( @@ -776,6 +810,11 @@ def test_get_from_cache(self, r, r2, cache): "cache_config": CacheConfig(max_size=128), "ssl": True, }, + { + "cache_config": CacheConfig(max_size=128), + "ssl": True, + "decode_responses": True, + }, ], indirect=True, ) @@ -786,11 +825,11 @@ def test_get_from_custom_cache(self, r, r2): # add key to redis r.set("foo", "bar") # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" + assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" + in [b"bar", "bar"] ) # change key in redis (cause invalidation) r2.set("foo", "barbar") @@ -798,11 +837,11 @@ def test_get_from_custom_cache(self, r, r2): # between data appears in socket buffer time.sleep(0.1) # Retrieves a new value from server and cache_data it - assert r.get("foo") == b"barbar" + assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached assert ( cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"barbar" + in [b"barbar", "barbar"] ) @pytest.mark.parametrize( From 20f3851e6634518afe5900a859bb39e08decdd6c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 13 Sep 2024 15:48:02 +0300 Subject: [PATCH 68/78] Codestyle changes --- redis/connection.py | 3 +- tests/test_cache.py | 134 ++++++++++++++++++++++---------------------- 2 files changed, 69 insertions(+), 68 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index a2d9f4f2c2..d8af15e366 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -42,9 +42,10 @@ HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, compare_versions, + ensure_string, format_error_message, get_lib_version, - str_if_bytes, ensure_string, + str_if_bytes, ) if HIREDIS_AVAILABLE: diff --git a/tests/test_cache.py b/tests/test_cache.py index e106cdb156..725c09aa1f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -72,19 +72,19 @@ def test_get_from_given_cache(self, r, r2): # get key from redis and save in local cache assert r.get("foo") in [b"bar", "bar"] # get key from local cache - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", @@ -117,19 +117,19 @@ def test_get_from_default_cache(self, r, r2): # get key from redis and save in local cache assert r.get("foo") in [b"bar", "bar"] # get key from local cache - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", @@ -379,19 +379,19 @@ def test_get_from_cache(self, r): # get key from redis and save in local cache assert r.get("foo") in [b"bar", "bar"] # get key from local cache - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] # Make sure that cache is shared between nodes. assert ( cache == r.nodes_manager.get_node_from_slot(1).redis_connection.get_cache() @@ -420,19 +420,19 @@ def test_get_from_custom_cache(self, r, r2): # get key from redis and save in local cache assert r.get("foo") in [b"bar", "bar"] # get key from local cache - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", @@ -640,7 +640,7 @@ class TestSentinelCache: "cache": DefaultCache(CacheConfig(max_size=128)), "force_master_ip": "localhost", "decode_responses": True, - } + }, ], indirect=True, ) @@ -651,19 +651,19 @@ def test_get_from_cache(self, master): # get key from redis and save in local cache_data assert master.get("foo") in [b"bar", "bar"] # get key from local cache_data - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) master.set("foo", "barbar") # get key from redis assert master.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", @@ -687,19 +687,19 @@ def test_get_from_default_cache(self, r, r2): # get key from redis and save in local cache_data assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") # Retrieves a new value from server and cache_data it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "sentinel_setup", @@ -760,7 +760,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -#@skip_if_resp_version(2) +# @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") class TestSSLCache: @pytest.mark.parametrize( @@ -774,7 +774,7 @@ class TestSSLCache: "cache": DefaultCache(CacheConfig(max_size=128)), "ssl": True, "decode_responses": True, - } + }, ], indirect=True, ) @@ -786,10 +786,10 @@ def test_get_from_cache(self, r, r2, cache): # get key from redis and save in local cache_data assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) assert r2.set("foo", "barbar") # Timeout needed for SSL connection because there's timeout @@ -798,10 +798,10 @@ def test_get_from_cache(self, r, r2, cache): # Retrieves a new value from server and cache_data it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", @@ -827,10 +827,10 @@ def test_get_from_custom_cache(self, r, r2): # get key from redis and save in local cache_data assert r.get("foo") in [b"bar", "bar"] # get key from local cache_data - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"bar", "bar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"bar", + "bar", + ] # change key in redis (cause invalidation) r2.set("foo", "barbar") # Timeout needed for SSL connection because there's timeout @@ -839,10 +839,10 @@ def test_get_from_custom_cache(self, r, r2): # Retrieves a new value from server and cache_data it assert r.get("foo") in [b"barbar", "barbar"] # Make sure that new value was cached - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - in [b"barbar", "barbar"] - ) + assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [ + b"barbar", + "barbar", + ] @pytest.mark.parametrize( "r", From 2d8cf27e43c84d29ac789af6c145838698b3769f Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:14:20 +0300 Subject: [PATCH 69/78] Removed background healthcheck, use connection reference approach instead --- docs/resp3_features.rst | 6 -- redis/cache.py | 23 ++++---- redis/connection.py | 77 +++++++++----------------- redis/scheduler.py | 60 -------------------- redis/sentinel.py | 1 - tests/test_cache.py | 117 +-------------------------------------- tests/test_connection.py | 32 +++++++++++ tests/test_scheduler.py | 39 ------------- 8 files changed, 70 insertions(+), 285 deletions(-) delete mode 100644 redis/scheduler.py delete mode 100644 tests/test_scheduler.py diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst index 7ca8a83b6d..3f35250233 100644 --- a/docs/resp3_features.rst +++ b/docs/resp3_features.rst @@ -98,10 +98,4 @@ Enable caching with custom cache implementation: CacheImpl should implement a `CacheInterface` specified in `redis.cache` package. -Explicit disconnect - -It's important to call `close()` to properly close the connection to server. -For caching purposes, we're using a separate thread that performs health checks with configurable interval and it relies on -`close()` to be called before the shutdown. - More robust documentation soon will be available at `official Redis documentation `_. diff --git a/redis/cache.py b/redis/cache.py index 61626beef7..f448a98d3c 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,4 +1,5 @@ import copy +import weakref from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum @@ -32,14 +33,21 @@ def __eq__(self, other): class CacheEntry: def __init__( - self, cache_key: CacheKey, cache_value: bytes, status: CacheEntryStatus + self, + cache_key: CacheKey, + cache_value: bytes, + status: CacheEntryStatus, + connection_ref, ): self.cache_key = cache_key self.cache_value = cache_value self.status = status + self.connection_ref = connection_ref def __hash__(self): - return hash((self.cache_key, self.cache_value, self.status)) + return hash( + (self.cache_key, self.cache_value, self.status, self.connection_ref) + ) def __eq__(self, other): return hash(self) == hash(other) @@ -86,10 +94,6 @@ def get_max_size(self) -> int: def get_eviction_policy(self): pass - @abstractmethod - def get_health_check_interval(self) -> float: - pass - @abstractmethod def is_exceeds_max_size(self, count: int) -> bool: pass @@ -182,7 +186,7 @@ def get(self, key: CacheKey) -> Union[CacheEntry, None]: return None self._eviction_policy.touch(key) - return copy.deepcopy(entry) + return entry def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: response = [] @@ -360,12 +364,10 @@ def __init__( max_size: int = DEFAULT_MAX_SIZE, cache_class: Any = DEFAULT_CACHE_CLASS, eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - health_check_interval: float = 2.0, ): self._cache_class = cache_class self._max_size = max_size self._eviction_policy = eviction_policy - self._health_check_interval = health_check_interval def get_cache_class(self): return self._cache_class @@ -376,9 +378,6 @@ def get_max_size(self) -> int: def get_eviction_policy(self) -> EvictionPolicy: return self._eviction_policy - def get_health_check_interval(self) -> float: - return self._health_check_interval - def is_exceeds_max_size(self, count: int) -> bool: return count > self._max_size diff --git a/redis/connection.py b/redis/connection.py index d8af15e366..6aae2101c2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -35,7 +35,6 @@ TimeoutError, ) from .retry import Retry -from .scheduler import Scheduler from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, @@ -741,12 +740,18 @@ class CacheProxyConnection(ConnectionInterface): MIN_ALLOWED_VERSION = "7.4.0" DEFAULT_SERVER_NAME = "redis" - def __init__(self, conn: ConnectionInterface, cache: CacheInterface): + def __init__( + self, + conn: ConnectionInterface, + cache: CacheInterface, + pool_lock: threading.Lock, + ): self.pid = os.getpid() self._conn = conn self.retry = self._conn.retry self.host = self._conn.host self.port = self._conn.port + self._pool_lock = pool_lock self._cache = cache self._cache_lock = threading.Lock() self._current_command_cache_key = None @@ -824,9 +829,17 @@ def send_command(self, *args, **kwargs): ) with self._cache_lock: - # If current command reply already cached - # prevent sending data over socket. + # We have to trigger invalidation processing in case if + # it was cached by another connection to avoid + # queueing invalidations in stale connections. if self._cache.get(self._current_command_cache_key): + entry = self._cache.get(self._current_command_cache_key) + + if entry.connection_ref != self._conn: + with self._pool_lock: + while entry.connection_ref.can_read(): + entry.connection_ref.read_response(push_request=True) + return # Set temporary entry value to prevent @@ -836,6 +849,7 @@ def send_command(self, *args, **kwargs): cache_key=self._current_command_cache_key, cache_value=self.DUMMY_CACHE_VALUE, status=CacheEntryStatus.IN_PROGRESS, + connection_ref=self._conn, ) ) @@ -857,7 +871,9 @@ def read_response( and self._cache.get(self._current_command_cache_key).status != CacheEntryStatus.IN_PROGRESS ): - return self._cache.get(self._current_command_cache_key).cache_value + return copy.deepcopy( + self._cache.get(self._current_command_cache_key).cache_value + ) response = self._conn.read_response( disable_decoding=disable_decoding, @@ -879,13 +895,9 @@ def read_response( # Cache only responses that still valid # and wasn't invalidated by another connection in meantime. if cache_entry is not None: - self._cache.set( - CacheEntry( - cache_key=self._current_command_cache_key, - cache_value=response, - status=CacheEntryStatus.VALID, - ) - ) + cache_entry.status = CacheEntryStatus.VALID + cache_entry.cache_value = response + self._cache.set(cache_entry) return response @@ -1284,9 +1296,6 @@ def __init__( self.max_connections = max_connections self.cache = None self._cache_factory = cache_factory - self._scheduler = None - self._hc_cancel_event = None - self._hc_thread = None if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): if connection_kwargs.get("protocol") not in [3, "3"]: @@ -1307,8 +1316,6 @@ def __init__( self.connection_kwargs.get("cache_config") ).get_cache() - self._scheduler = Scheduler() - connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) @@ -1322,7 +1329,6 @@ def __init__( # release the lock. self._fork_lock = threading.Lock() self.reset() - self.run_scheduled_healthcheck() def __repr__(self) -> (str, str): return ( @@ -1452,7 +1458,7 @@ def make_connection(self) -> "ConnectionInterface": if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache + self.connection_class(**self.connection_kwargs), self.cache, self._lock ) return self.connection_class(**self.connection_kwargs) @@ -1501,8 +1507,6 @@ def disconnect(self, inuse_connections: bool = True) -> None: for connection in connections: connection.disconnect() - self.stop_scheduled_healthcheck() - def close(self) -> None: """Close the pool, disconnecting all connections""" self.disconnect() @@ -1514,33 +1518,6 @@ def set_retry(self, retry: "Retry") -> None: for conn in self._in_use_connections: conn.retry = retry - def run_scheduled_healthcheck(self) -> None: - # Run scheduled healthcheck to avoid stale invalidations in idle connections. - if self.cache is not None and self._scheduler is not None: - self._hc_cancel_event = threading.Event() - hc_interval = self.cache.get_config().get_health_check_interval() - self._hc_thread = self._scheduler.run_with_interval( - self._perform_health_check, hc_interval, self._hc_cancel_event - ) - - def stop_scheduled_healthcheck(self) -> None: - # Send an event to stop scheduled healthcheck execution. - if self._hc_cancel_event is not None and not self._hc_cancel_event.is_set(): - self._hc_cancel_event.set() - - # Joins healthcheck thread on disconnect. - if self._hc_thread is not None and not self._hc_thread.is_alive(): - self._hc_thread.join() - - def _perform_health_check(self, done: threading.Event) -> None: - self._checkpid() - with self._lock: - while self._available_connections: - conn = self._available_connections.pop() - conn.send_command("PING") - conn.read_response() - done.set() - class BlockingConnectionPool(ConnectionPool): """ @@ -1620,7 +1597,7 @@ def make_connection(self): "Make a fresh connection." if self.cache is not None: connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache + self.connection_class(**self.connection_kwargs), self.cache, self._lock ) else: connection = self.connection_class(**self.connection_kwargs) @@ -1705,5 +1682,3 @@ def disconnect(self): self._checkpid() for connection in self._connections: connection.disconnect() - - self.stop_scheduled_healthcheck() diff --git a/redis/scheduler.py b/redis/scheduler.py deleted file mode 100644 index 9bcfc740a0..0000000000 --- a/redis/scheduler.py +++ /dev/null @@ -1,60 +0,0 @@ -import threading -import time -from typing import Callable - - -class Scheduler: - - def __init__(self, polling_period: float = 0.1): - """ - :param polling_period: Period between polling operations. - Needs to detect when new job has to be scheduled. - """ - self.polling_period = polling_period - - def run_with_interval( - self, - func: Callable, - interval: float, - cancel: threading.Event, - args: tuple = (), - ) -> threading.Thread: - """ - Run scheduled execution with given interval - in a separate thread until cancel event won't be set. - """ - done = threading.Event() - thread = threading.Thread( - target=self._run_timer, args=(func, interval, (done, *args), done, cancel) - ) - thread.start() - return thread - - def _get_timer( - self, func: Callable, interval: float, args: tuple - ) -> threading.Timer: - timer = threading.Timer(interval=interval, function=func, args=args) - return timer - - def _run_timer( - self, - func: Callable, - interval: float, - args: tuple, - done: threading.Event, - cancel: threading.Event, - ): - timer = self._get_timer(func, interval, args) - timer.start() - - while not cancel.is_set(): - if done.is_set(): - done.clear() - timer.join() - timer = self._get_timer(func, interval, args) - timer.start() - else: - time.sleep(self.polling_period) - - timer.cancel() - timer.join() diff --git a/redis/sentinel.py b/redis/sentinel.py index 17cc926a98..01e210794c 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -115,7 +115,6 @@ def get_master_address(self): connection_pool = self.connection_pool_ref() if connection_pool is not None: connection_pool.disconnect(inuse_connections=False) - connection_pool.run_scheduled_healthcheck() return master_address def rotate_slaves(self): diff --git a/tests/test_cache.py b/tests/test_cache.py index 725c09aa1f..4fef9e3d41 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -131,35 +131,6 @@ def test_get_from_default_cache(self, r, r2): "barbar", ] - @pytest.mark.parametrize( - "r", - [ - { - "cache_config": CacheConfig(max_size=128, health_check_interval=1.0), - "single_connection_client": False, - }, - ], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache(self, r): - cache = r.get_cache() - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" - ) - # change key in redis (cause invalidation) - r.set("foo", "barbar") - # Wait for health check - time.sleep(1.0) - # Make sure that value was invalidated - assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None - @pytest.mark.parametrize( "r", [ @@ -434,34 +405,6 @@ def test_get_from_custom_cache(self, r, r2): "barbar", ] - @pytest.mark.parametrize( - "r", - [ - { - "cache_config": CacheConfig(max_size=128), - }, - ], - indirect=True, - ) - @pytest.mark.onlycluster - def test_health_check_invalidate_cache(self, r, r2): - cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache - assert r.get("foo") == b"bar" - # get key from local cache - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" - ) - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # Wait for health check - time.sleep(2) - # Make sure that value was invalidated - assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None - @pytest.mark.parametrize( "r", [ @@ -701,35 +644,6 @@ def test_get_from_default_cache(self, r, r2): "barbar", ] - @pytest.mark.parametrize( - "sentinel_setup", - [ - { - "cache_config": CacheConfig(max_size=128, health_check_interval=1.0), - "force_master_ip": "localhost", - } - ], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache(self, master, cache): - cache = master.get_cache() - # add key to redis - master.set("foo", "bar") - # get key from redis and save in local cache_data - assert master.get("foo") == b"bar" - # get key from local cache_data - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" - ) - # change key in redis (cause invalidation) - master.set("foo", "barbar") - # Wait for health check - time.sleep(1.0) - # Make sure that value was invalidated - assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None - @pytest.mark.parametrize( "sentinel_setup", [ @@ -760,7 +674,7 @@ def test_cache_clears_on_disconnect(self, master, cache): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -# @skip_if_resp_version(2) +@skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") class TestSSLCache: @pytest.mark.parametrize( @@ -844,35 +758,6 @@ def test_get_from_custom_cache(self, r, r2): "barbar", ] - @pytest.mark.parametrize( - "r", - [ - { - "cache_config": CacheConfig(max_size=128, health_check_interval=1.0), - "ssl": True, - } - ], - indirect=True, - ) - @pytest.mark.onlynoncluster - def test_health_check_invalidate_cache(self, r, r2): - cache = r.get_cache() - # add key to redis - r.set("foo", "bar") - # get key from redis and save in local cache_data - assert r.get("foo") == b"bar" - # get key from local cache_data - assert ( - cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value - == b"bar" - ) - # change key in redis (cause invalidation) - r2.set("foo", "barbar") - # Wait for health check - time.sleep(1.0) - # Make sure that value was invalidated - assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None - @pytest.mark.parametrize( "r", [ diff --git a/tests/test_connection.py b/tests/test_connection.py index 215f74850b..b2749b72e4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,5 +1,8 @@ +import copy import socket +import threading import types +import weakref from typing import Any from unittest import mock from unittest.mock import call, patch @@ -510,3 +513,32 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): call(CacheKey(command="GET", redis_keys=("foo",))), ] ) + + def test_triggers_invalidation_processing_on_another_connection( + self, mock_cache, mock_connection + ): + mock_connection.retry = "mock" + mock_connection.host = "mock" + mock_connection.port = "mock" + + another_conn = copy.deepcopy(mock_connection) + another_conn.can_read.side_effect = [True, False] + another_conn.read_response.return_value = None + cache_entry = CacheEntry( + cache_key=CacheKey(command="GET", redis_keys=("foo",)), + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=another_conn, + ) + mock_cache.is_cachable.return_value = True + mock_cache.get.return_value = cache_entry + mock_connection.can_read.return_value = False + + proxy_connection = CacheProxyConnection( + mock_connection, mock_cache, threading.Lock() + ) + proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) + + assert proxy_connection.read_response() == b"bar" + assert another_conn.can_read.call_count == 2 + another_conn.read_response.assert_called_once() diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py deleted file mode 100644 index 8ccb2125c4..0000000000 --- a/tests/test_scheduler.py +++ /dev/null @@ -1,39 +0,0 @@ -import threading -import time - -import pytest -from redis.scheduler import Scheduler - - -class TestScheduler: - @pytest.mark.parametrize( - "polling_period,interval,expected_count", - [ - (0.001, 0.1, (8, 9)), - (0.1, 0.2, (3, 4)), - (0.1, 2, (0, 0)), - ], - ids=[ - "small polling period (0.001s)", - "large polling period (0.1s)", - "interval larger than timeout - no execution", - ], - ) - def test_run_with_interval(self, polling_period, interval, expected_count): - scheduler = Scheduler(polling_period=polling_period) - cancel_event = threading.Event() - counter = 0 - - def callback(done: threading.Event): - nonlocal counter - counter += 1 - done.set() - - scheduler.run_with_interval( - func=callback, interval=interval, cancel=cancel_event - ) - time.sleep(1) - cancel_event.set() - cancel_event.wait() - # Due to flacky nature of test case, provides at least 2 possible values. - assert counter in expected_count From 1f6876a7766559e55b4cd2a7aa5613aec2b3eb79 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:15:56 +0300 Subject: [PATCH 70/78] Removed unused imports --- redis/cache.py | 2 -- tests/test_connection.py | 1 - 2 files changed, 3 deletions(-) diff --git a/redis/cache.py b/redis/cache.py index f448a98d3c..0c8f80fd3d 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,5 +1,3 @@ -import copy -import weakref from abc import ABC, abstractmethod from collections import OrderedDict from enum import Enum diff --git a/tests/test_connection.py b/tests/test_connection.py index b2749b72e4..b4cfebcae1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,7 +2,6 @@ import socket import threading import types -import weakref from typing import Any from unittest import mock from unittest.mock import call, patch From 30717aaf5bbf13c33129e96b2cf612e076ae3410 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:28:59 +0300 Subject: [PATCH 71/78] Fixed broken tests --- tests/test_cache.py | 102 ++++++++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 37 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 4fef9e3d41..3555f26e46 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -811,12 +811,13 @@ def test_get_size(self): @pytest.mark.parametrize( "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) - def test_set_non_existing_cache_key(self, cache_key): + def test_set_non_existing_cache_key(self, cache_key, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.get(cache_key).cache_value == b"val" @@ -824,12 +825,13 @@ def test_set_non_existing_cache_key(self, cache_key): @pytest.mark.parametrize( "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) - def test_set_updates_existing_cache_key(self, cache_key): + def test_set_updates_existing_cache_key(self, cache_key, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.get(cache_key).cache_value == b"val" @@ -839,6 +841,7 @@ def test_set_updates_existing_cache_key(self, cache_key): cache_key=cache_key, cache_value=b"new_val", status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.get(cache_key).cache_value == b"new_val" @@ -846,16 +849,17 @@ def test_set_updates_existing_cache_key(self, cache_key): @pytest.mark.parametrize( "cache_key", [{"command": "HRANDFIELD", "redis_keys": ("bar",)}], indirect=True ) - def test_set_does_not_store_not_allowed_key(self, cache_key): + def test_set_does_not_store_not_allowed_key(self, cache_key, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) assert not cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) - def test_set_evict_lru_cache_key_on_reaching_max_size(self): + def test_set_evict_lru_cache_key_on_reaching_max_size(self, mock_connection): cache = DefaultCache(CacheConfig(max_size=3)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("foo1",)) @@ -864,17 +868,20 @@ def test_set_evict_lru_cache_key_on_reaching_max_size(self): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) @@ -887,7 +894,8 @@ def test_set_evict_lru_cache_key_on_reaching_max_size(self): cache_key4 = CacheKey(command="GET", redis_keys=("foo3",)) assert cache.set( CacheEntry( - cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID + cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) @@ -898,12 +906,13 @@ def test_set_evict_lru_cache_key_on_reaching_max_size(self): @pytest.mark.parametrize( "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True ) - def test_get_return_correct_value(self, cache_key): + def test_get_return_correct_value(self, cache_key, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID + cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.get(cache_key).cache_value == b"val" @@ -917,13 +926,14 @@ def test_get_return_correct_value(self, cache_key): cache_key=cache_key, cache_value=b"new_val", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) # Make sure that result is immutable. assert result.cache_value != cache.get(cache_key).cache_value - def test_delete_by_cache_keys_removes_associated_entries(self): + def test_delete_by_cache_keys_removes_associated_entries(self, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) @@ -934,17 +944,20 @@ def test_delete_by_cache_keys_removes_associated_entries(self): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) @@ -956,7 +969,7 @@ def test_delete_by_cache_keys_removes_associated_entries(self): assert len(cache.get_collection()) == 1 assert cache.get(cache_key3).cache_value == b"bar2" - def test_delete_by_redis_keys_removes_associated_entries(self): + def test_delete_by_redis_keys_removes_associated_entries(self, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) @@ -967,22 +980,26 @@ def test_delete_by_redis_keys_removes_associated_entries(self): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID + cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) @@ -990,7 +1007,7 @@ def test_delete_by_redis_keys_removes_associated_entries(self): assert len(cache.get_collection()) == 1 assert cache.get(cache_key4).cache_value == b"bar3" - def test_flush(self): + def test_flush(self, mock_connection): cache = DefaultCache(CacheConfig(max_size=5)) cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) @@ -1000,17 +1017,20 @@ def test_flush(self): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID + cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) @@ -1023,7 +1043,7 @@ def test_type(self): policy = LRUPolicy() assert policy.type == EvictionPolicyType.time_based - def test_evict_next(self): + def test_evict_next(self, mock_connection): cache = DefaultCache( CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) @@ -1034,19 +1054,21 @@ def test_evict_next(self): assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert policy.evict_next() == cache_key1 assert cache.get(cache_key1) is None - def test_evict_many(self): + def test_evict_many(self, mock_connection): cache = DefaultCache( CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) @@ -1057,17 +1079,20 @@ def test_evict_many(self): assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"baz", status=CacheEntryStatus.VALID + cache_key=cache_key3, cache_value=b"baz", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) @@ -1078,7 +1103,7 @@ def test_evict_many(self): with pytest.raises(ValueError, match="Evictions count is above cache size"): policy.evict_many(99) - def test_touch(self): + def test_touch(self, mock_connection): cache = DefaultCache( CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) @@ -1089,19 +1114,22 @@ def test_touch(self): cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.get_collection().popitem(last=True)[0] == cache_key2 cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID + cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) From 98bf72d5a30114af4bb39bc96ed91b6da7d4edb3 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:33:45 +0300 Subject: [PATCH 72/78] Codestyle changes --- tests/test_cache.py | 158 +++++++++++++++++++++++++++++--------------- 1 file changed, 105 insertions(+), 53 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 3555f26e46..0d0be72efd 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -816,8 +816,10 @@ def test_set_non_existing_cache_key(self, cache_key, mock_connection): assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.get(cache_key).cache_value == b"val" @@ -830,8 +832,10 @@ def test_set_updates_existing_cache_key(self, cache_key, mock_connection): assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.get(cache_key).cache_value == b"val" @@ -854,8 +858,10 @@ def test_set_does_not_store_not_allowed_key(self, cache_key, mock_connection): assert not cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -868,20 +874,26 @@ def test_set_evict_lru_cache_key_on_reaching_max_size(self, mock_connection): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -894,8 +906,10 @@ def test_set_evict_lru_cache_key_on_reaching_max_size(self, mock_connection): cache_key4 = CacheKey(command="GET", redis_keys=("foo3",)) assert cache.set( CacheEntry( - cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key4, + cache_value=b"bar3", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -911,8 +925,10 @@ def test_get_return_correct_value(self, cache_key, mock_connection): assert cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"val", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key, + cache_value=b"val", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.get(cache_key).cache_value == b"val" @@ -926,7 +942,7 @@ def test_get_return_correct_value(self, cache_key, mock_connection): cache_key=cache_key, cache_value=b"new_val", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + connection_ref=mock_connection, ) ) @@ -944,20 +960,26 @@ def test_delete_by_cache_keys_removes_associated_entries(self, mock_connection): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -980,26 +1002,34 @@ def test_delete_by_redis_keys_removes_associated_entries(self, mock_connection): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key4, cache_value=b"bar3", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key4, + cache_value=b"bar3", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -1017,20 +1047,26 @@ def test_flush(self, mock_connection): # Set 3 different keys assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"bar1", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"bar1", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"bar2", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key3, + cache_value=b"bar2", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -1054,14 +1090,18 @@ def test_evict_next(self, mock_connection): assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -1079,20 +1119,26 @@ def test_evict_many(self, mock_connection): assert cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.set( CacheEntry( - cache_key=cache_key3, cache_value=b"baz", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key3, + cache_value=b"baz", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) @@ -1114,22 +1160,28 @@ def test_touch(self, mock_connection): cache.set( CacheEntry( - cache_key=cache_key1, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key1, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.get_collection().popitem(last=True)[0] == cache_key2 cache.set( CacheEntry( - cache_key=cache_key2, cache_value=b"foo", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key2, + cache_value=b"foo", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) From cdf85046da55c2c3e0e3b8f065a6969f37e0d127 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:36:39 +0300 Subject: [PATCH 73/78] Fixed additional broken tests --- tests/test_connection.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index b4cfebcae1..eb1533eca5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -428,7 +428,8 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"bar", status=CacheEntryStatus.VALID + cache_key=cache_key, cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ) assert cache.get(cache_key).cache_value == b"bar" @@ -438,7 +439,7 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.host = "mock" mock_connection.port = "mock" - proxy_connection = CacheProxyConnection(mock_connection, cache) + proxy_connection = CacheProxyConnection(mock_connection, cache, threading.Lock()) proxy_connection.disconnect() assert len(cache.get_collection()) == 0 @@ -456,28 +457,32 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, status=CacheEntryStatus.IN_PROGRESS, + connection_ref=mock_connection ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ), ] mock_connection.send_command.return_value = Any mock_connection.read_response.return_value = b"bar" mock_connection.can_read.return_value = False - proxy_connection = CacheProxyConnection(mock_connection, mock_cache) + proxy_connection = CacheProxyConnection(mock_connection, mock_cache, threading.Lock()) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) assert proxy_connection.read_response() == b"bar" assert proxy_connection.read_response() == b"bar" @@ -490,6 +495,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, status=CacheEntryStatus.IN_PROGRESS, + connection_ref=mock_connection ) ), call( @@ -497,6 +503,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, + connection_ref=mock_connection ) ), ] From 7340aad021ffd484fb8e5b90089ff1890ad3e3dc Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:38:20 +0300 Subject: [PATCH 74/78] Codestyle changes --- tests/test_connection.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index eb1533eca5..c79ee7b735 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -428,8 +428,10 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): cache.set( CacheEntry( - cache_key=cache_key, cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + cache_key=cache_key, + cache_value=b"bar", + status=CacheEntryStatus.VALID, + connection_ref=mock_connection, ) ) assert cache.get(cache_key).cache_value == b"bar" @@ -439,7 +441,9 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.host = "mock" mock_connection.port = "mock" - proxy_connection = CacheProxyConnection(mock_connection, cache, threading.Lock()) + proxy_connection = CacheProxyConnection( + mock_connection, cache, threading.Lock() + ) proxy_connection.disconnect() assert len(cache.get_collection()) == 0 @@ -457,32 +461,34 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, status=CacheEntryStatus.IN_PROGRESS, - connection_ref=mock_connection + connection_ref=mock_connection, ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + connection_ref=mock_connection, ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + connection_ref=mock_connection, ), CacheEntry( cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + connection_ref=mock_connection, ), ] mock_connection.send_command.return_value = Any mock_connection.read_response.return_value = b"bar" mock_connection.can_read.return_value = False - proxy_connection = CacheProxyConnection(mock_connection, mock_cache, threading.Lock()) + proxy_connection = CacheProxyConnection( + mock_connection, mock_cache, threading.Lock() + ) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) assert proxy_connection.read_response() == b"bar" assert proxy_connection.read_response() == b"bar" @@ -495,7 +501,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE, status=CacheEntryStatus.IN_PROGRESS, - connection_ref=mock_connection + connection_ref=mock_connection, ) ), call( @@ -503,7 +509,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): cache_key=CacheKey(command="GET", redis_keys=("foo",)), cache_value=b"bar", status=CacheEntryStatus.VALID, - connection_ref=mock_connection + connection_ref=mock_connection, ) ), ] From c36219f8ca95979779df3f2225b994b1828eb885 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 15:48:30 +0300 Subject: [PATCH 75/78] Increased timer to avoid flackiness --- tests/test_asyncio/test_hash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_asyncio/test_hash.py b/tests/test_asyncio/test_hash.py index e31ea7eaf3..8d94799fbb 100644 --- a/tests/test_asyncio/test_hash.py +++ b/tests/test_asyncio/test_hash.py @@ -177,7 +177,7 @@ async def test_hexpireat_multiple_fields(r): ) exp_time = int((datetime.now() + timedelta(seconds=1)).timestamp()) assert await r.hexpireat("test:hash", exp_time, "field1", "field2") == [1, 1] - await asyncio.sleep(1.1) + await asyncio.sleep(1.5) assert await r.hexists("test:hash", "field1") is False assert await r.hexists("test:hash", "field2") is False assert await r.hexists("test:hash", "field3") is True From 6d63a59ca735d55102064a74ae18fa74b815f41e Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 16:10:45 +0300 Subject: [PATCH 76/78] Restrict tests cause of PyPy --- tests/test_connection.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_connection.py b/tests/test_connection.py index c79ee7b735..0a5b9c15d9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,5 @@ import copy +import platform import socket import threading import types @@ -448,6 +449,9 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): assert len(cache.get_collection()) == 0 + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="Pypy doesn't support side_effect" + ) def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): mock_connection.retry = "mock" mock_connection.host = "mock" @@ -526,6 +530,9 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): ] ) + @pytest.mark.skipif( + platform.python_implementation() == "PyPy", reason="Pypy doesn't support side_effect" + ) def test_triggers_invalidation_processing_on_another_connection( self, mock_cache, mock_connection ): From 18a4d3d2d8dd4f86399b06450acfcd0588c69d10 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 19 Sep 2024 16:12:08 +0300 Subject: [PATCH 77/78] Codestyle changes --- tests/test_connection.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 0a5b9c15d9..13e5c95a24 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -450,7 +450,8 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): assert len(cache.get_collection()) == 0 @pytest.mark.skipif( - platform.python_implementation() == "PyPy", reason="Pypy doesn't support side_effect" + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", ) def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): mock_connection.retry = "mock" @@ -531,7 +532,8 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): ) @pytest.mark.skipif( - platform.python_implementation() == "PyPy", reason="Pypy doesn't support side_effect" + platform.python_implementation() == "PyPy", + reason="Pypy doesn't support side_effect", ) def test_triggers_invalidation_processing_on_another_connection( self, mock_cache, mock_connection From 94ddab185de1f1a67656cb3ca052d294196f7d75 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 27 Sep 2024 13:47:39 +0300 Subject: [PATCH 78/78] Updated docs, convert getters function to properties, added dataclasses --- docs/resp3_features.rst | 4 ++-- redis/cache.py | 52 ++++++++++++++++++++-------------------- tests/test_cache.py | 48 ++++++++++++++++++------------------- tests/test_connection.py | 6 ++--- 4 files changed, 55 insertions(+), 55 deletions(-) diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst index 3f35250233..326495b775 100644 --- a/docs/resp3_features.rst +++ b/docs/resp3_features.rst @@ -72,7 +72,7 @@ Client-side caching ------------------- Client-side caching is a technique used to create high performance services. -It exploits the memory available on application servers, servers that are usually distinct computers compared to the database nodes, to store some subset of the database information directly in the application side. +It utilizes the memory on application servers, typically separate from the database nodes, to cache a subset of the data directly on the application side. For more information please check `official Redis documentation `_. Please notice that this feature only available with RESP3 protocol enabled in sync client only. Supported in standalone, Cluster and Sentinel clients. @@ -98,4 +98,4 @@ Enable caching with custom cache implementation: CacheImpl should implement a `CacheInterface` specified in `redis.cache` package. -More robust documentation soon will be available at `official Redis documentation `_. +More comprehensive documentation soon will be available at `official Redis documentation `_. diff --git a/redis/cache.py b/redis/cache.py index 0c8f80fd3d..9971edd256 100644 --- a/redis/cache.py +++ b/redis/cache.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections import OrderedDict +from dataclasses import dataclass from enum import Enum from typing import Any, List, Optional, Union @@ -14,19 +15,10 @@ class EvictionPolicyType(Enum): frequency_based = "frequency_based" +@dataclass(frozen=True) class CacheKey: - def __init__(self, command: str, redis_keys: tuple): - self.command = command - self.redis_keys = redis_keys - - def get_redis_keys(self) -> tuple: - return self.redis_keys - - def __hash__(self): - return hash((self.command, self.redis_keys)) - - def __eq__(self, other): - return hash(self) == hash(other) + command: str + redis_keys: tuple class CacheEntry: @@ -102,20 +94,24 @@ def is_allowed_to_cache(self, command: str) -> bool: class CacheInterface(ABC): + @property @abstractmethod - def get_collection(self) -> OrderedDict: + def collection(self) -> OrderedDict: pass + @property @abstractmethod - def get_config(self) -> CacheConfigurationInterface: + def config(self) -> CacheConfigurationInterface: pass + @property @abstractmethod - def get_eviction_policy(self) -> EvictionPolicyInterface: + def eviction_policy(self) -> EvictionPolicyInterface: pass + @property @abstractmethod - def get_size(self) -> int: + def size(self) -> int: pass @abstractmethod @@ -153,16 +149,20 @@ def __init__( self._eviction_policy = self._cache_config.get_eviction_policy().value() self._eviction_policy.cache = self - def get_collection(self) -> OrderedDict: + @property + def collection(self) -> OrderedDict: return self._cache - def get_config(self) -> CacheConfigurationInterface: + @property + def config(self) -> CacheConfigurationInterface: return self._cache_config - def get_eviction_policy(self) -> EvictionPolicyInterface: + @property + def eviction_policy(self) -> EvictionPolicyInterface: return self._eviction_policy - def get_size(self) -> int: + @property + def size(self) -> int: return len(self._cache) def set(self, entry: CacheEntry) -> bool: @@ -206,7 +206,7 @@ def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: if isinstance(redis_key, bytes): redis_key = redis_key.decode() for cache_key in self._cache: - if redis_key in cache_key.get_redis_keys(): + if redis_key in cache_key.redis_keys: keys_to_delete.append(cache_key) response.append(True) @@ -242,18 +242,18 @@ def type(self) -> EvictionPolicyType: def evict_next(self) -> CacheKey: self._assert_cache() - popped_entry = self._cache.get_collection().popitem(last=False) + popped_entry = self._cache.collection.popitem(last=False) return popped_entry[0] def evict_many(self, count: int) -> List[CacheKey]: self._assert_cache() - if count > len(self._cache.get_collection()): + if count > len(self._cache.collection): raise ValueError("Evictions count is above cache size") popped_keys = [] for _ in range(count): - popped_entry = self._cache.get_collection().popitem(last=False) + popped_entry = self._cache.collection.popitem(last=False) popped_keys.append(popped_entry[0]) return popped_keys @@ -261,10 +261,10 @@ def evict_many(self, count: int) -> List[CacheKey]: def touch(self, cache_key: CacheKey) -> None: self._assert_cache() - if self._cache.get_collection().get(cache_key) is None: + if self._cache.collection.get(cache_key) is None: raise ValueError("Given entry does not belong to the cache") - self._cache.get_collection().move_to_end(cache_key) + self._cache.collection.move_to_end(cache_key) def _assert_cache(self): if self.cache is None or not isinstance(self.cache, CacheInterface): diff --git a/tests/test_cache.py b/tests/test_cache.py index 0d0be72efd..1803646094 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -41,7 +41,7 @@ def r(request): @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @pytest.mark.onlynoncluster -@skip_if_resp_version(2) +# @skip_if_resp_version(2) @skip_if_server_version_lt("7.4.0") class TestCache: @pytest.mark.parametrize( @@ -109,8 +109,8 @@ def test_get_from_given_cache(self, r, r2): @pytest.mark.onlynoncluster def test_get_from_default_cache(self, r, r2): cache = r.get_cache() - assert isinstance(cache.get_eviction_policy(), LRUPolicy) - assert cache.get_config().get_max_size() == 128 + assert isinstance(cache.eviction_policy, LRUPolicy) + assert cache.config.get_max_size() == 128 # add key to redis r.set("foo", "bar") @@ -161,7 +161,7 @@ def test_cache_clears_on_disconnect(self, r, cache): # Force disconnection r.connection_pool.get_connection("_").disconnect() # Make sure cache is empty - assert cache.get_size() == 0 + assert cache.size == 0 @pytest.mark.parametrize( "r", @@ -207,7 +207,7 @@ def test_cache_lru_eviction(self, r, cache): assert r.get("foo4") == b"bar4" # the first key is not in the local cache anymore assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None - assert cache.get_size() == 3 + assert cache.size == 3 @pytest.mark.parametrize( "r", @@ -321,7 +321,7 @@ def test_cache_flushed_on_server_flush(self, r): # Flush server and trying to access cached entry assert r.flushall() assert r.get("foo") is None - assert cache.get_size() == 0 + assert cache.size == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -383,8 +383,8 @@ def test_get_from_cache(self, r): ) def test_get_from_custom_cache(self, r, r2): cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache() - assert isinstance(cache.get_eviction_policy(), LRUPolicy) - assert cache.get_config().get_max_size() == 128 + assert isinstance(cache.eviction_policy, LRUPolicy) + assert cache.config.get_max_size() == 128 # add key to redis assert r.set("foo", "bar") @@ -431,7 +431,7 @@ def test_cache_clears_on_disconnect(self, r, r2): 12000 ).redis_connection.connection_pool.get_connection("_").disconnect() # Make sure cache is empty - assert cache.get_size() == 0 + assert cache.size == 0 @pytest.mark.parametrize( "r", @@ -564,7 +564,7 @@ def test_cache_flushed_on_server_flush(self, r, cache): # Flush server and trying to access cached entry assert r.flushall() assert r.get("foo{slot}") is None - assert cache.get_size() == 0 + assert cache.size == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -623,7 +623,7 @@ def test_get_from_cache(self, master): ) def test_get_from_default_cache(self, r, r2): cache = r.get_cache() - assert isinstance(cache.get_eviction_policy(), LRUPolicy) + assert isinstance(cache.eviction_policy, LRUPolicy) # add key to redis r.set("foo", "bar") @@ -669,7 +669,7 @@ def test_cache_clears_on_disconnect(self, master, cache): # Force disconnection master.connection_pool.get_connection("_").disconnect() # Make sure cache_data is empty - assert cache.get_size() == 0 + assert cache.size == 0 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -734,7 +734,7 @@ def test_get_from_cache(self, r, r2, cache): ) def test_get_from_custom_cache(self, r, r2): cache = r.get_cache() - assert isinstance(cache.get_eviction_policy(), LRUPolicy) + assert isinstance(cache.eviction_policy, LRUPolicy) # add key to redis r.set("foo", "bar") @@ -798,15 +798,15 @@ def test_cache_invalidate_all_related_responses(self, r): class TestUnitDefaultCache: def test_get_eviction_policy(self): cache = DefaultCache(CacheConfig(max_size=5)) - assert isinstance(cache.get_eviction_policy(), LRUPolicy) + assert isinstance(cache.eviction_policy, LRUPolicy) def test_get_max_size(self): cache = DefaultCache(CacheConfig(max_size=5)) - assert cache.get_config().get_max_size() == 5 + assert cache.config.get_max_size() == 5 def test_get_size(self): cache = DefaultCache(CacheConfig(max_size=5)) - assert cache.get_size() == 0 + assert cache.size == 0 @pytest.mark.parametrize( "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True @@ -988,7 +988,7 @@ def test_delete_by_cache_keys_removes_associated_entries(self, mock_connection): True, False, ] - assert len(cache.get_collection()) == 1 + assert len(cache.collection) == 1 assert cache.get(cache_key3).cache_value == b"bar2" def test_delete_by_redis_keys_removes_associated_entries(self, mock_connection): @@ -1034,7 +1034,7 @@ def test_delete_by_redis_keys_removes_associated_entries(self, mock_connection): ) assert cache.delete_by_redis_keys([b"foo", b"foo1"]) == [True, True, True] - assert len(cache.get_collection()) == 1 + assert len(cache.collection) == 1 assert cache.get(cache_key4).cache_value == b"bar3" def test_flush(self, mock_connection): @@ -1071,7 +1071,7 @@ def test_flush(self, mock_connection): ) assert cache.flush() == 3 - assert len(cache.get_collection()) == 0 + assert len(cache.collection) == 0 class TestUnitLRUPolicy: @@ -1083,7 +1083,7 @@ def test_evict_next(self, mock_connection): cache = DefaultCache( CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) - policy = cache.get_eviction_policy() + policy = cache.eviction_policy cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) @@ -1112,7 +1112,7 @@ def test_evict_many(self, mock_connection): cache = DefaultCache( CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) - policy = cache.get_eviction_policy() + policy = cache.eviction_policy cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) cache_key3 = CacheKey(command="GET", redis_keys=("baz",)) @@ -1153,7 +1153,7 @@ def test_touch(self, mock_connection): cache = DefaultCache( CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU) ) - policy = cache.get_eviction_policy() + policy = cache.eviction_policy cache_key1 = CacheKey(command="GET", redis_keys=("foo",)) cache_key2 = CacheKey(command="GET", redis_keys=("bar",)) @@ -1175,7 +1175,7 @@ def test_touch(self, mock_connection): ) ) - assert cache.get_collection().popitem(last=True)[0] == cache_key2 + assert cache.collection.popitem(last=True)[0] == cache_key2 cache.set( CacheEntry( cache_key=cache_key2, @@ -1186,7 +1186,7 @@ def test_touch(self, mock_connection): ) policy.touch(cache_key1) - assert cache.get_collection().popitem(last=True)[0] == cache_key1 + assert cache.collection.popitem(last=True)[0] == cache_key1 def test_throws_error_on_invalid_cache(self): policy = LRUPolicy() diff --git a/tests/test_connection.py b/tests/test_connection.py index 13e5c95a24..a58703e3b5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -411,8 +411,8 @@ def test_creates_cache_with_given_configuration(self, mock_cache): ) assert isinstance(connection_pool.cache, CacheInterface) - assert connection_pool.cache.get_config().get_max_size() == 100 - assert isinstance(connection_pool.cache.get_eviction_policy(), LRUPolicy) + assert connection_pool.cache.config.get_max_size() == 100 + assert isinstance(connection_pool.cache.eviction_policy, LRUPolicy) connection_pool.disconnect() def test_make_connection_proxy_connection_on_given_cache(self): @@ -447,7 +447,7 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): ) proxy_connection.disconnect() - assert len(cache.get_collection()) == 0 + assert len(cache.collection) == 0 @pytest.mark.skipif( platform.python_implementation() == "PyPy",