Skip to content

Commit 01c76e8

Browse files
committed
get json() from modules working in asyncio.cluster + AsyncClusterPipeline (for json())
1 parent 12f95de commit 01c76e8

File tree

4 files changed

+456
-12
lines changed

4 files changed

+456
-12
lines changed

redis/asyncio/cluster.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Generator,
1111
List,
1212
Mapping,
13+
MutableMapping,
1314
Optional,
1415
Type,
1516
TypeVar,
@@ -25,7 +26,7 @@
2526
parse_url,
2627
)
2728
from redis.asyncio.parser import CommandsParser
28-
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
29+
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis, CaseInsensitiveDict
2930
from redis.cluster import (
3031
PIPELINE_BLOCKED_COMMANDS,
3132
PRIMARY,
@@ -37,7 +38,7 @@
3738
get_node_name,
3839
parse_cluster_slots,
3940
)
40-
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
41+
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands, AsyncRedisModuleCommands
4142
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
4243
from redis.exceptions import (
4344
AskError,
@@ -78,7 +79,7 @@ class ClusterParser(DefaultParser):
7879
)
7980

8081

81-
class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
82+
class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands, AsyncRedisModuleCommands):
8283
"""
8384
Create a new RedisCluster client.
8485
@@ -152,6 +153,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand
152153
- none of the `host`/`port` & `startup_nodes` were provided
153154
154155
"""
156+
response_callbacks: MutableMapping[Union[str, bytes], ResponseCallbackT]
155157

156158
@classmethod
157159
def from_url(cls, url: str, **kwargs: Any) -> "RedisCluster":
@@ -298,7 +300,10 @@ def __init__(
298300
# Call our on_connect function to configure READONLY mode
299301
kwargs["redis_connect_func"] = self.on_connect
300302

301-
kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy()
303+
kwargs["cluster_response_callbacks"] = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS)
304+
self.cluster_response_callbacks = CaseInsensitiveDict(
305+
self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS
306+
)
302307
self.connection_kwargs = kwargs
303308

304309
if startup_nodes:
@@ -324,7 +329,7 @@ def __init__(
324329
self.commands_parser = CommandsParser()
325330
self.node_flags = self.__class__.NODE_FLAGS.copy()
326331
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
327-
self.response_callbacks = kwargs["response_callbacks"]
332+
self.cluster_response_callbacks = kwargs["cluster_response_callbacks"]
328333
self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy()
329334
self.result_callbacks[
330335
"CLUSTER SLOTS"
@@ -479,7 +484,7 @@ def get_connection_kwargs(self) -> Dict[str, Optional[Any]]:
479484

480485
def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None:
481486
"""Set a custom response callback."""
482-
self.response_callbacks[command] = callback
487+
self.cluster_response_callbacks[command] = callback
483488

484489
async def _determine_nodes(
485490
self, command: str, *args: Any, node_flag: Optional[str] = None
@@ -809,7 +814,7 @@ def __init__(
809814
self.max_connections = max_connections
810815
self.connection_class = connection_class
811816
self.connection_kwargs = connection_kwargs
812-
self.response_callbacks = connection_kwargs.pop("response_callbacks", {})
817+
self.response_callbacks = connection_kwargs.pop("cluster_response_callbacks", {})
813818

814819
self._connections: List[Connection] = []
815820
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
@@ -1206,7 +1211,7 @@ async def close(self, attr: str = "nodes_cache") -> None:
12061211
)
12071212

12081213

1209-
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
1214+
class ClusterPipeline(RedisCluster):
12101215
"""
12111216
Create a new ClusterPipeline object.
12121217
@@ -1245,9 +1250,21 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm
12451250

12461251
__slots__ = ("_command_stack", "_client")
12471252

1248-
def __init__(self, client: RedisCluster) -> None:
1253+
def __init__(
1254+
self,
1255+
client: RedisCluster,
1256+
nodes_manager=None,
1257+
commands_parser=None,
1258+
result_callbacks=None,
1259+
startup_nodes=None,
1260+
read_from_replicas=False,
1261+
cluster_error_retry_attempts=5,
1262+
reinitialize_steps=10,
1263+
lock=None,
1264+
**kwargs,
1265+
) -> None:
12491266
self._client = client
1250-
1267+
self.cluster_response_callbacks = self.RESPONSE_CALLBACKS
12511268
self._command_stack: List["PipelineCommand"] = []
12521269

12531270
async def initialize(self) -> "ClusterPipeline":

redis/commands/json/__init__.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from json import JSONDecodeError, JSONDecoder, JSONEncoder
22

33
import redis
4-
4+
from redis.asyncio.cluster import RedisCluster as AsyncRedisCluster
5+
from ...asyncio.client import Pipeline as AsyncioPipeline
6+
from ...asyncio.cluster import ClusterPipeline as AsyncioClusterPipeline
57
from ..helpers import nativestr
6-
from .commands import JSONCommands
8+
from .commands import JSONCommands, AsyncJSONCommands
79
from .decoders import bulk_of_jsons, decode_list
810

911

@@ -131,3 +133,75 @@ class ClusterPipeline(JSONCommands, redis.cluster.ClusterPipeline):
131133

132134
class Pipeline(JSONCommands, redis.client.Pipeline):
133135
"""Pipeline for the module."""
136+
137+
138+
class AsyncJSON(JSON, AsyncJSONCommands):
139+
"""
140+
Create a client for talking to json.
141+
:param decoder:
142+
:type json.JSONDecoder: An instance of json.JSONDecoder
143+
:param encoder:
144+
:type json.JSONEncoder: An instance of json.JSONEncoder
145+
"""
146+
147+
def _decode(self, obj):
148+
if obj is None:
149+
return obj
150+
151+
try:
152+
x = self.__decoder__.decode(obj)
153+
if x is None:
154+
raise TypeError
155+
return x
156+
except TypeError:
157+
try:
158+
return self.__decoder__.decode(obj.decode())
159+
except AttributeError:
160+
return decode_list(obj)
161+
except (AttributeError, JSONDecodeError):
162+
return decode_list(obj)
163+
164+
def _encode(self, obj):
165+
return self.__encoder__.encode(obj)
166+
167+
def pipeline(self, transaction=True, shard_hint=None):
168+
"""Creates a pipeline for the JSON module, that can be used for executing
169+
JSON commands, as well as classic core commands.
170+
Usage example:
171+
r = redis.Redis()
172+
pipe = r.json().pipeline()
173+
pipe.jsonset('foo', '.', {'hello!': 'world'})
174+
pipe.jsonget('foo')
175+
pipe.jsonget('notakey')
176+
"""
177+
if isinstance(self.client, AsyncRedisCluster):
178+
p = AsyncioClusterPipeline(
179+
nodes_manager=self.client.nodes_manager,
180+
commands_parser=self.client.commands_parser,
181+
startup_nodes=self.client.nodes_manager.startup_nodes,
182+
result_callbacks=self.client.result_callbacks,
183+
cluster_response_callbacks=self.client.cluster_response_callbacks,
184+
cluster_error_retry_attempts=self.client.cluster_error_retry_attempts,
185+
read_from_replicas=self.client.read_from_replicas,
186+
reinitialize_steps=self.client.reinitialize_steps,
187+
lock=self.client._lock,
188+
)
189+
190+
else:
191+
p = AsyncioPipeline(
192+
connection_pool=self.client.connection_pool,
193+
response_callbacks=self.MODULE_CALLBACKS,
194+
transaction=transaction,
195+
shard_hint=shard_hint,
196+
)
197+
p._encode = self._encode
198+
p._decode = self._decode
199+
return p
200+
201+
202+
class AsyncClusterPipeline(AsyncJSONCommands, AsyncioClusterPipeline):
203+
"""Cluster pipeline for the module."""
204+
205+
206+
class AsyncPipeline(AsyncJSONCommands, AsyncioPipeline):
207+
"""Pipeline for the module."""

0 commit comments

Comments
 (0)