Skip to content

Commit 6f6b6f6

Browse files
committed
Change host_port_remap into a callable
1 parent fe72d7b commit 6f6b6f6

File tree

2 files changed

+17
-27
lines changed

2 files changed

+17
-27
lines changed

redis/asyncio/cluster.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from typing import (
77
Any,
8+
Callable,
89
Deque,
910
Dict,
1011
Generator,
@@ -251,7 +252,7 @@ def __init__(
251252
ssl_certfile: Optional[str] = None,
252253
ssl_check_hostname: bool = False,
253254
ssl_keyfile: Optional[str] = None,
254-
host_port_remap: List[Dict[str, Any]] = [],
255+
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
255256
) -> None:
256257
if db:
257258
raise RedisClusterException(
@@ -1059,7 +1060,7 @@ def __init__(
10591060
startup_nodes: List["ClusterNode"],
10601061
require_full_coverage: bool,
10611062
connection_kwargs: Dict[str, Any],
1062-
host_port_remap: List[Dict[str, Any]] = [],
1063+
host_port_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None,
10631064
) -> None:
10641065
self.startup_nodes = {node.name: node for node in startup_nodes}
10651066
self.require_full_coverage = require_full_coverage
@@ -1322,22 +1323,8 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
13221323
internal value. Useful if the client is not connecting directly
13231324
to the cluster.
13241325
"""
1325-
for map_entry in self.host_port_remap:
1326-
mapped = False
1327-
if "from_host" in map_entry:
1328-
if host != map_entry["from_host"]:
1329-
continue
1330-
else:
1331-
host = map_entry["to_host"]
1332-
mapped = True
1333-
if "from_port" in map_entry:
1334-
if port != map_entry["from_port"]:
1335-
continue
1336-
else:
1337-
port = map_entry["to_port"]
1338-
mapped = True
1339-
if mapped:
1340-
break
1326+
if self.host_port_remap:
1327+
return self.host_port_remap(host, port)
13411328
return host, port
13421329

13431330

tests/test_asyncio/test_cwe_404.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,37 +181,39 @@ async def test_cluster(request, redis_addr):
181181
remap_base = 7372
182182
n_nodes = 6
183183

184-
remap = []
184+
def remap(host, port):
185+
return host, remap_base + port - cluster_port
186+
185187
proxies = []
186188
for i in range(n_nodes):
187189
port = cluster_port + i
188190
remapped = remap_base + i
189-
remap.append({"from_port": port, "to_port": remapped})
190191
forward_addr = redis_addr[0], port
191192
proxy = DelayProxy(
192193
addr=("127.0.0.1", remapped), redis_addr=forward_addr, delay=0
193194
)
194195
proxies.append(proxy)
195196

196-
# start proxies
197-
await asyncio.gather(*[p.start() for p in proxies])
198-
197+
# helpers to work with all or any proxy
199198
def all_clear():
200199
for p in proxies:
201200
p.send_event.clear()
202201

203-
async def wait_for_send():
202+
async def any_wait():
204203
asyncio.wait(
205204
[p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED
206205
)
207206

208207
@contextlib.contextmanager
209-
def override(delay: int = 0):
208+
def all_override(delay: int = 0):
210209
with contextlib.ExitStack() as stack:
211210
for p in proxies:
212211
stack.enter_context(p.override(delay=delay))
213212
yield
214213

214+
# start proxies
215+
await asyncio.gather(*[p.start() for p in proxies])
216+
215217
with contextlib.closing(
216218
RedisCluster.from_url(f"redis://127.0.0.1:{remap_base}", host_port_remap=remap)
217219
) as r:
@@ -220,10 +222,10 @@ def override(delay: int = 0):
220222
await r.set("bar", "bar")
221223

222224
all_clear()
223-
with override(delay=delay):
225+
with all_override(delay=delay):
224226
t = asyncio.create_task(r.get("foo"))
225227
# cannot wait on the send event, we don't know which node will be used
226-
await wait_for_send()
228+
await any_wait()
227229
await asyncio.sleep(delay)
228230
t.cancel()
229231
with pytest.raises(asyncio.CancelledError):
@@ -237,4 +239,5 @@ async def doit():
237239

238240
await asyncio.gather(*[doit() for _ in range(10)])
239241

242+
# stop proxies
240243
await asyncio.gather(*(p.stop() for p in proxies))

0 commit comments

Comments
 (0)