@@ -725,13 +725,14 @@ def ensure_string(key):
725
725
726
726
727
727
class CacheProxyConnection (ConnectionInterface ):
728
- def __init__ (self , conn : ConnectionInterface , cache : Cache , conf : CacheConfiguration ):
728
+ def __init__ (self , conn : ConnectionInterface , cache : Cache , conf : CacheConfiguration , cache_lock : threading . Lock ):
729
729
self .pid = os .getpid ()
730
730
self ._conn = conn
731
731
self .retry = self ._conn .retry
732
732
self .host = self ._conn .host
733
733
self .port = self ._conn .port
734
734
self ._cache = cache
735
+ self ._cache_lock = cache_lock
735
736
self ._conf = conf
736
737
self ._current_command_hash = None
737
738
self ._current_command_keys = None
@@ -758,7 +759,8 @@ def on_connect(self):
758
759
self ._conn .on_connect ()
759
760
760
761
def disconnect (self , * args ):
761
- self ._cache .clear ()
762
+ with self ._cache_lock :
763
+ self ._cache .clear ()
762
764
self ._conn .disconnect (* args )
763
765
764
766
def check_health (self ):
@@ -789,12 +791,13 @@ def send_command(self, *args, **kwargs):
789
791
if not isinstance (self ._current_command_keys , list ):
790
792
raise TypeError ("Cache keys must be a list." )
791
793
792
- # If current command reply already cached prevent sending data over socket.
793
- if self ._cache .get (self ._current_command_hash ):
794
- return
794
+ with self ._cache_lock :
795
+ # If current command reply already cached prevent sending data over socket.
796
+ if self ._cache .get (self ._current_command_hash ):
797
+ return
795
798
796
- # Set temporary entry as a status to prevent race condition from another connection.
797
- self ._cache [self ._current_command_hash ] = "caching-in-progress"
799
+ # Set temporary entry as a status to prevent race condition from another connection.
800
+ self ._cache [self ._current_command_hash ] = "caching-in-progress"
798
801
799
802
# Send command over socket only if it's allowed read-only command that not yet cached.
800
803
self ._conn .send_command (* args , ** kwargs )
@@ -803,40 +806,42 @@ def can_read(self, timeout=0):
803
806
return self ._conn .can_read (timeout )
804
807
805
808
def read_response (self , disable_decoding = False , * , disconnect_on_error = True , push_request = False ):
806
- # Check if command response exists in a cache and it's not in progress.
807
- if (
808
- self ._current_command_hash in self ._cache
809
- and self ._cache [self ._current_command_hash ] != "caching-in-progress"
810
- ):
811
- return self ._cache [self ._current_command_hash ]
809
+ with self ._cache_lock :
810
+ # Check if command response exists in a cache and it's not in progress.
811
+ if (
812
+ self ._current_command_hash in self ._cache
813
+ and self ._cache [self ._current_command_hash ] != "caching-in-progress"
814
+ ):
815
+ return self ._cache [self ._current_command_hash ]
812
816
813
817
response = self ._conn .read_response (
814
818
disable_decoding = disable_decoding ,
815
819
disconnect_on_error = disconnect_on_error ,
816
820
push_request = push_request
817
821
)
818
822
819
- # If response is None prevent from caching and remove temporary cache entry.
820
- if response is None :
821
- self ._cache .pop (self ._current_command_hash )
822
- return response
823
- # Prevent not-allowed command from caching.
824
- elif self ._current_command_hash is None :
825
- return response
826
-
827
- # Create separate mapping for keys or add current response to associated keys.
828
- for key in self ._current_command_keys :
829
- if key in self ._keys_mapping :
830
- if self ._current_command_hash not in self ._keys_mapping [key ]:
831
- self ._keys_mapping [key ].append (self ._current_command_hash )
832
- else :
833
- self ._keys_mapping [key ] = [self ._current_command_hash ]
823
+ with self ._cache_lock :
824
+ # If response is None prevent from caching and remove temporary cache entry.
825
+ if response is None :
826
+ self ._cache .pop (self ._current_command_hash )
827
+ return response
828
+ # Prevent not-allowed command from caching.
829
+ elif self ._current_command_hash is None :
830
+ return response
831
+
832
+ # Create separate mapping for keys or add current response to associated keys.
833
+ for key in self ._current_command_keys :
834
+ if key in self ._keys_mapping :
835
+ if self ._current_command_hash not in self ._keys_mapping [key ]:
836
+ self ._keys_mapping [key ].append (self ._current_command_hash )
837
+ else :
838
+ self ._keys_mapping [key ] = [self ._current_command_hash ]
834
839
835
- cache_entry = self ._cache .get (self ._current_command_hash , None )
840
+ cache_entry = self ._cache .get (self ._current_command_hash , None )
836
841
837
- # Cache only responses that still valid and wasn't invalidated by another connection in meantime
838
- if cache_entry is not None :
839
- self ._cache [self ._current_command_hash ] = response
842
+ # Cache only responses that still valid and wasn't invalidated by another connection in meantime
843
+ if cache_entry is not None :
844
+ self ._cache [self ._current_command_hash ] = response
840
845
841
846
return response
842
847
@@ -864,18 +869,19 @@ def _process_pending_invalidations(self):
864
869
def _on_invalidation_callback (
865
870
self , data : List [Union [str , Optional [List [str ]]]]
866
871
):
867
- # Flush cache when DB flushed on server-side
868
- if data [1 ] is None :
869
- self ._cache .clear ()
870
- else :
871
- for key in data [1 ]:
872
- normalized_key = ensure_string (key )
873
- if normalized_key in self ._keys_mapping :
874
- # Make sure that all command responses associated with this key will be deleted
875
- for cache_key in self ._keys_mapping [normalized_key ]:
876
- self ._cache .pop (cache_key )
877
- # Removes key from mapping cache
878
- self ._keys_mapping .pop (normalized_key )
872
+ with self ._cache_lock :
873
+ # Flush cache when DB flushed on server-side
874
+ if data [1 ] is None :
875
+ self ._cache .clear ()
876
+ else :
877
+ for key in data [1 ]:
878
+ normalized_key = ensure_string (key )
879
+ if normalized_key in self ._keys_mapping :
880
+ # Make sure that all command responses associated with this key will be deleted
881
+ for cache_key in self ._keys_mapping [normalized_key ]:
882
+ self ._cache .pop (cache_key )
883
+ # Removes key from mapping cache
884
+ self ._keys_mapping .pop (normalized_key )
879
885
880
886
881
887
class SSLConnection (Connection ):
@@ -1238,13 +1244,15 @@ def __init__(
1238
1244
self .max_connections = max_connections
1239
1245
self ._cache = None
1240
1246
self ._cache_conf = None
1247
+ self .cache_lock = None
1241
1248
self .scheduler = None
1242
1249
1243
1250
if connection_kwargs .get ("use_cache" ):
1244
1251
if connection_kwargs .get ("protocol" ) not in [3 , "3" ]:
1245
1252
raise RedisError ("Client caching is only supported with RESP version 3" )
1246
1253
1247
1254
self ._cache_conf = CacheConfiguration (** self .connection_kwargs )
1255
+ self ._cache_lock = threading .Lock ()
1248
1256
1249
1257
cache = self .connection_kwargs .get ("cache" )
1250
1258
if cache is not None :
@@ -1399,7 +1407,12 @@ def make_connection(self) -> "ConnectionInterface":
1399
1407
self ._created_connections += 1
1400
1408
1401
1409
if self ._cache is not None and self ._cache_conf is not None :
1402
- return CacheProxyConnection (self .connection_class (** self .connection_kwargs ), self ._cache , self ._cache_conf )
1410
+ return CacheProxyConnection (
1411
+ self .connection_class (** self .connection_kwargs ),
1412
+ self ._cache ,
1413
+ self ._cache_conf ,
1414
+ self ._cache_lock
1415
+ )
1403
1416
1404
1417
return self .connection_class (** self .connection_kwargs )
1405
1418
@@ -1547,7 +1560,8 @@ def make_connection(self):
1547
1560
connection = CacheProxyConnection (
1548
1561
self .connection_class (** self .connection_kwargs ),
1549
1562
self ._cache ,
1550
- self ._cache_conf
1563
+ self ._cache_conf ,
1564
+ self ._cache_lock
1551
1565
)
1552
1566
else :
1553
1567
connection = self .connection_class (** self .connection_kwargs )
0 commit comments