Skip to content

Commit e77cb60

Browse files
committed
Added locking when accesing cache object
1 parent 33f656e commit e77cb60

File tree

2 files changed

+139
-73
lines changed

2 files changed

+139
-73
lines changed

redis/connection.py

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -725,13 +725,14 @@ def ensure_string(key):
725725

726726

727727
class CacheProxyConnection(ConnectionInterface):
728-
def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration):
728+
def __init__(self, conn: ConnectionInterface, cache: Cache, conf: CacheConfiguration, cache_lock: threading.Lock):
729729
self.pid = os.getpid()
730730
self._conn = conn
731731
self.retry = self._conn.retry
732732
self.host = self._conn.host
733733
self.port = self._conn.port
734734
self._cache = cache
735+
self._cache_lock = cache_lock
735736
self._conf = conf
736737
self._current_command_hash = None
737738
self._current_command_keys = None
@@ -758,7 +759,8 @@ def on_connect(self):
758759
self._conn.on_connect()
759760

760761
def disconnect(self, *args):
761-
self._cache.clear()
762+
with self._cache_lock:
763+
self._cache.clear()
762764
self._conn.disconnect(*args)
763765

764766
def check_health(self):
@@ -789,12 +791,13 @@ def send_command(self, *args, **kwargs):
789791
if not isinstance(self._current_command_keys, list):
790792
raise TypeError("Cache keys must be a list.")
791793

792-
# If current command reply already cached prevent sending data over socket.
793-
if self._cache.get(self._current_command_hash):
794-
return
794+
with self._cache_lock:
795+
# If current command reply already cached prevent sending data over socket.
796+
if self._cache.get(self._current_command_hash):
797+
return
795798

796-
# Set temporary entry as a status to prevent race condition from another connection.
797-
self._cache[self._current_command_hash] = "caching-in-progress"
799+
# Set temporary entry as a status to prevent race condition from another connection.
800+
self._cache[self._current_command_hash] = "caching-in-progress"
798801

799802
# Send command over socket only if it's allowed read-only command that not yet cached.
800803
self._conn.send_command(*args, **kwargs)
@@ -803,40 +806,42 @@ def can_read(self, timeout=0):
803806
return self._conn.can_read(timeout)
804807

805808
def read_response(self, disable_decoding=False, *, disconnect_on_error=True, push_request=False):
806-
# Check if command response exists in a cache and it's not in progress.
807-
if (
808-
self._current_command_hash in self._cache
809-
and self._cache[self._current_command_hash] != "caching-in-progress"
810-
):
811-
return self._cache[self._current_command_hash]
809+
with self._cache_lock:
810+
# Check if command response exists in a cache and it's not in progress.
811+
if (
812+
self._current_command_hash in self._cache
813+
and self._cache[self._current_command_hash] != "caching-in-progress"
814+
):
815+
return self._cache[self._current_command_hash]
812816

813817
response = self._conn.read_response(
814818
disable_decoding=disable_decoding,
815819
disconnect_on_error=disconnect_on_error,
816820
push_request=push_request
817821
)
818822

819-
# If response is None prevent from caching and remove temporary cache entry.
820-
if response is None:
821-
self._cache.pop(self._current_command_hash)
822-
return response
823-
# Prevent not-allowed command from caching.
824-
elif self._current_command_hash is None:
825-
return response
826-
827-
# Create separate mapping for keys or add current response to associated keys.
828-
for key in self._current_command_keys:
829-
if key in self._keys_mapping:
830-
if self._current_command_hash not in self._keys_mapping[key]:
831-
self._keys_mapping[key].append(self._current_command_hash)
832-
else:
833-
self._keys_mapping[key] = [self._current_command_hash]
823+
with self._cache_lock:
824+
# If response is None prevent from caching and remove temporary cache entry.
825+
if response is None:
826+
self._cache.pop(self._current_command_hash)
827+
return response
828+
# Prevent not-allowed command from caching.
829+
elif self._current_command_hash is None:
830+
return response
831+
832+
# Create separate mapping for keys or add current response to associated keys.
833+
for key in self._current_command_keys:
834+
if key in self._keys_mapping:
835+
if self._current_command_hash not in self._keys_mapping[key]:
836+
self._keys_mapping[key].append(self._current_command_hash)
837+
else:
838+
self._keys_mapping[key] = [self._current_command_hash]
834839

835-
cache_entry = self._cache.get(self._current_command_hash, None)
840+
cache_entry = self._cache.get(self._current_command_hash, None)
836841

837-
# Cache only responses that still valid and wasn't invalidated by another connection in meantime
838-
if cache_entry is not None:
839-
self._cache[self._current_command_hash] = response
842+
# Cache only responses that still valid and wasn't invalidated by another connection in meantime
843+
if cache_entry is not None:
844+
self._cache[self._current_command_hash] = response
840845

841846
return response
842847

@@ -864,18 +869,19 @@ def _process_pending_invalidations(self):
864869
def _on_invalidation_callback(
865870
self, data: List[Union[str, Optional[List[str]]]]
866871
):
867-
# Flush cache when DB flushed on server-side
868-
if data[1] is None:
869-
self._cache.clear()
870-
else:
871-
for key in data[1]:
872-
normalized_key = ensure_string(key)
873-
if normalized_key in self._keys_mapping:
874-
# Make sure that all command responses associated with this key will be deleted
875-
for cache_key in self._keys_mapping[normalized_key]:
876-
self._cache.pop(cache_key)
877-
# Removes key from mapping cache
878-
self._keys_mapping.pop(normalized_key)
872+
with self._cache_lock:
873+
# Flush cache when DB flushed on server-side
874+
if data[1] is None:
875+
self._cache.clear()
876+
else:
877+
for key in data[1]:
878+
normalized_key = ensure_string(key)
879+
if normalized_key in self._keys_mapping:
880+
# Make sure that all command responses associated with this key will be deleted
881+
for cache_key in self._keys_mapping[normalized_key]:
882+
self._cache.pop(cache_key)
883+
# Removes key from mapping cache
884+
self._keys_mapping.pop(normalized_key)
879885

880886

881887
class SSLConnection(Connection):
@@ -1238,13 +1244,15 @@ def __init__(
12381244
self.max_connections = max_connections
12391245
self._cache = None
12401246
self._cache_conf = None
1247+
self.cache_lock = None
12411248
self.scheduler = None
12421249

12431250
if connection_kwargs.get("use_cache"):
12441251
if connection_kwargs.get("protocol") not in [3, "3"]:
12451252
raise RedisError("Client caching is only supported with RESP version 3")
12461253

12471254
self._cache_conf = CacheConfiguration(**self.connection_kwargs)
1255+
self._cache_lock = threading.Lock()
12481256

12491257
cache = self.connection_kwargs.get("cache")
12501258
if cache is not None:
@@ -1399,7 +1407,12 @@ def make_connection(self) -> "ConnectionInterface":
13991407
self._created_connections += 1
14001408

14011409
if self._cache is not None and self._cache_conf is not None:
1402-
return CacheProxyConnection(self.connection_class(**self.connection_kwargs), self._cache, self._cache_conf)
1410+
return CacheProxyConnection(
1411+
self.connection_class(**self.connection_kwargs),
1412+
self._cache,
1413+
self._cache_conf,
1414+
self._cache_lock
1415+
)
14031416

14041417
return self.connection_class(**self.connection_kwargs)
14051418

@@ -1547,7 +1560,8 @@ def make_connection(self):
15471560
connection = CacheProxyConnection(
15481561
self.connection_class(**self.connection_kwargs),
15491562
self._cache,
1550-
self._cache_conf
1563+
self._cache_conf,
1564+
self._cache_lock
15511565
)
15521566
else:
15531567
connection = self.connection_class(**self.connection_kwargs)

0 commit comments

Comments
 (0)