Skip to content

Commit caafb98

Browse files
committed
add cluster "host_port_remap" feature
1 parent 45cabb5 commit caafb98

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

redis/asyncio/cluster.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import warnings
66
from typing import (
77
Any,
8+
Callable,
89
Deque,
910
Dict,
1011
Generator,
1112
List,
1213
Mapping,
1314
Optional,
15+
Tuple,
1416
Type,
1517
TypeVar,
1618
Union,
@@ -250,6 +252,7 @@ def __init__(
250252
ssl_certfile: Optional[str] = None,
251253
ssl_check_hostname: bool = False,
252254
ssl_keyfile: Optional[str] = None,
255+
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
253256
) -> None:
254257
if db:
255258
raise RedisClusterException(
@@ -337,7 +340,12 @@ def __init__(
337340
if host and port:
338341
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))
339342

340-
self.nodes_manager = NodesManager(startup_nodes, require_full_coverage, kwargs)
343+
self.nodes_manager = NodesManager(
344+
startup_nodes,
345+
require_full_coverage,
346+
kwargs,
347+
host_port_remap=host_port_remap,
348+
)
341349
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
342350
self.read_from_replicas = read_from_replicas
343351
self.reinitialize_steps = reinitialize_steps
@@ -1044,17 +1052,20 @@ class NodesManager:
10441052
"require_full_coverage",
10451053
"slots_cache",
10461054
"startup_nodes",
1055+
"host_port_remap",
10471056
)
10481057

10491058
def __init__(
10501059
self,
10511060
startup_nodes: List["ClusterNode"],
10521061
require_full_coverage: bool,
10531062
connection_kwargs: Dict[str, Any],
1063+
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
10541064
) -> None:
10551065
self.startup_nodes = {node.name: node for node in startup_nodes}
10561066
self.require_full_coverage = require_full_coverage
10571067
self.connection_kwargs = connection_kwargs
1068+
self.host_port_remap = host_port_remap
10581069

10591070
self.default_node: "ClusterNode" = None
10601071
self.nodes_cache: Dict[str, "ClusterNode"] = {}
@@ -1213,6 +1224,7 @@ async def initialize(self) -> None:
12131224
if host == "":
12141225
host = startup_node.host
12151226
port = int(primary_node[1])
1227+
host, port = self.remap_host_port(host, port)
12161228

12171229
target_node = tmp_nodes_cache.get(get_node_name(host, port))
12181230
if not target_node:
@@ -1231,6 +1243,7 @@ async def initialize(self) -> None:
12311243
for replica_node in replica_nodes:
12321244
host = replica_node[0]
12331245
port = replica_node[1]
1246+
host, port = self.remap_host_port(host, port)
12341247

12351248
target_replica_node = tmp_nodes_cache.get(
12361249
get_node_name(host, port)
@@ -1304,6 +1317,16 @@ async def close(self, attr: str = "nodes_cache") -> None:
13041317
)
13051318
)
13061319

1320+
def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
1321+
"""
1322+
Remap the host and port returned from the cluster to a different
1323+
internal value. Useful if the client is not connecting directly
1324+
to the cluster.
1325+
"""
1326+
if self.host_port_remap:
1327+
return self.host_port_remap(host, port)
1328+
return host, port
1329+
13071330

13081331
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
13091332
"""

0 commit comments

Comments
 (0)