5
5
import warnings
6
6
from typing import (
7
7
Any ,
8
+ Callable ,
8
9
Deque ,
9
10
Dict ,
10
11
Generator ,
11
12
List ,
12
13
Mapping ,
13
14
Optional ,
15
+ Tuple ,
14
16
Type ,
15
17
TypeVar ,
16
18
Union ,
@@ -250,6 +252,7 @@ def __init__(
250
252
ssl_certfile : Optional [str ] = None ,
251
253
ssl_check_hostname : bool = False ,
252
254
ssl_keyfile : Optional [str ] = None ,
255
+ host_port_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
253
256
) -> None :
254
257
if db :
255
258
raise RedisClusterException (
@@ -337,7 +340,12 @@ def __init__(
337
340
if host and port :
338
341
startup_nodes .append (ClusterNode (host , port , ** self .connection_kwargs ))
339
342
340
- self .nodes_manager = NodesManager (startup_nodes , require_full_coverage , kwargs )
343
+ self .nodes_manager = NodesManager (
344
+ startup_nodes ,
345
+ require_full_coverage ,
346
+ kwargs ,
347
+ host_port_remap = host_port_remap ,
348
+ )
341
349
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
342
350
self .read_from_replicas = read_from_replicas
343
351
self .reinitialize_steps = reinitialize_steps
@@ -1044,17 +1052,20 @@ class NodesManager:
1044
1052
"require_full_coverage" ,
1045
1053
"slots_cache" ,
1046
1054
"startup_nodes" ,
1055
+ "host_port_remap" ,
1047
1056
)
1048
1057
1049
1058
def __init__ (
1050
1059
self ,
1051
1060
startup_nodes : List ["ClusterNode" ],
1052
1061
require_full_coverage : bool ,
1053
1062
connection_kwargs : Dict [str , Any ],
1063
+ host_port_remap : Optional [Callable [[str , int ], Tuple [str , int ]]] = None ,
1054
1064
) -> None :
1055
1065
self .startup_nodes = {node .name : node for node in startup_nodes }
1056
1066
self .require_full_coverage = require_full_coverage
1057
1067
self .connection_kwargs = connection_kwargs
1068
+ self .host_port_remap = host_port_remap
1058
1069
1059
1070
self .default_node : "ClusterNode" = None
1060
1071
self .nodes_cache : Dict [str , "ClusterNode" ] = {}
@@ -1213,6 +1224,7 @@ async def initialize(self) -> None:
1213
1224
if host == "" :
1214
1225
host = startup_node .host
1215
1226
port = int (primary_node [1 ])
1227
+ host , port = self .remap_host_port (host , port )
1216
1228
1217
1229
target_node = tmp_nodes_cache .get (get_node_name (host , port ))
1218
1230
if not target_node :
@@ -1231,6 +1243,7 @@ async def initialize(self) -> None:
1231
1243
for replica_node in replica_nodes :
1232
1244
host = replica_node [0 ]
1233
1245
port = replica_node [1 ]
1246
+ host , port = self .remap_host_port (host , port )
1234
1247
1235
1248
target_replica_node = tmp_nodes_cache .get (
1236
1249
get_node_name (host , port )
@@ -1304,6 +1317,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
1304
1317
)
1305
1318
)
1306
1319
1320
+ def remap_host_port (self , host : str , port : int ) -> Tuple [str , int ]:
1321
+ """
1322
+ Remap the host and port returned from the cluster to a different
1323
+ internal value. Useful if the client is not connecting directly
1324
+ to the cluster.
1325
+ """
1326
+ if self .host_port_remap :
1327
+ return self .host_port_remap (host , port )
1328
+ return host , port
1329
+
1307
1330
1308
1331
class ClusterPipeline (AbstractRedis , AbstractRedisCluster , AsyncRedisClusterCommands ):
1309
1332
"""
0 commit comments