Skip to content

Commit f46829c

Browse files
dvora-h and leibale authored
Sharded pubsub (#2762)
* sharded pubsub * sharded pubsub Co-authored-by: Leibale Eidelman <[email protected]> * Sharded Pubsub TestPubSubSubscribeUnsubscribe * fix TestPubSubSubscribeUnsubscribe * more tests * linters * TestPubSubSubcommands * fix @leibale comments * linters * fix @chayim comments --------- Co-authored-by: Leibale Eidelman <[email protected]>
1 parent 312118b commit f46829c

File tree

6 files changed

+559
-37
lines changed

6 files changed

+559
-37
lines changed

redis/client.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,7 @@ class AbstractRedis:
833833
"QUIT": bool_ok,
834834
"STRALGO": parse_stralgo,
835835
"PUBSUB NUMSUB": parse_pubsub_numsub,
836+
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
836837
"RANDOMKEY": lambda r: r and r or None,
837838
"RESET": str_if_bytes,
838839
"SCAN": parse_scan,
@@ -1440,8 +1441,8 @@ class PubSub:
14401441
will be returned and it's safe to start listening again.
14411442
"""
14421443

1443-
PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
1444-
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
1444+
PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
1445+
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
14451446
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
14461447

14471448
def __init__(
@@ -1493,9 +1494,11 @@ def reset(self):
14931494
self.connection.clear_connect_callbacks()
14941495
self.connection_pool.release(self.connection)
14951496
self.connection = None
1496-
self.channels = {}
14971497
self.health_check_response_counter = 0
1498+
self.channels = {}
14981499
self.pending_unsubscribe_channels = set()
1500+
self.shard_channels = {}
1501+
self.pending_unsubscribe_shard_channels = set()
14991502
self.patterns = {}
15001503
self.pending_unsubscribe_patterns = set()
15011504
self.subscribed_event.clear()
@@ -1510,16 +1513,23 @@ def on_connect(self, connection):
15101513
# before passing them to [p]subscribe.
15111514
self.pending_unsubscribe_channels.clear()
15121515
self.pending_unsubscribe_patterns.clear()
1516+
self.pending_unsubscribe_shard_channels.clear()
15131517
if self.channels:
1514-
channels = {}
1515-
for k, v in self.channels.items():
1516-
channels[self.encoder.decode(k, force=True)] = v
1518+
channels = {
1519+
self.encoder.decode(k, force=True): v for k, v in self.channels.items()
1520+
}
15171521
self.subscribe(**channels)
15181522
if self.patterns:
1519-
patterns = {}
1520-
for k, v in self.patterns.items():
1521-
patterns[self.encoder.decode(k, force=True)] = v
1523+
patterns = {
1524+
self.encoder.decode(k, force=True): v for k, v in self.patterns.items()
1525+
}
15221526
self.psubscribe(**patterns)
1527+
if self.shard_channels:
1528+
shard_channels = {
1529+
self.encoder.decode(k, force=True): v
1530+
for k, v in self.shard_channels.items()
1531+
}
1532+
self.ssubscribe(**shard_channels)
15231533

15241534
@property
15251535
def subscribed(self):
@@ -1728,6 +1738,45 @@ def unsubscribe(self, *args):
17281738
self.pending_unsubscribe_channels.update(channels)
17291739
return self.execute_command("UNSUBSCRIBE", *args)
17301740

1741+
def ssubscribe(self, *args, target_node=None, **kwargs):
    """
    Subscribe this client to one or more shard channels.

    Positional arguments are channel names (or a single list of names);
    they are registered without a handler. Keyword arguments map a channel
    name to a callable: messages for that channel are passed to the
    callable instead of being returned from ``listen()`` or
    ``get_sharded_message()``.

    ``target_node`` is accepted but not used by this implementation.

    Returns whatever ``execute_command`` returns for ``SSUBSCRIBE``.
    """
    names = list_or_args(args[0], args[1:]) if args else args
    # Keyword entries override positional ones that share a name.
    requested = {**dict.fromkeys(names), **kwargs}
    reply = self.execute_command("SSUBSCRIBE", *requested)
    # Record the subscription only AFTER the command has been sent;
    # otherwise a reconnect triggered by the send would subscribe to the
    # same channels a second time.
    normalized = self._normalize_keys(requested)
    self.shard_channels.update(normalized)
    if not self.subscribed:
        # First active subscription: raise the subscribed flag and start
        # counting health-check responses from zero.
        self.subscribed_event.set()
        self.health_check_response_counter = 0
    # A re-subscribed channel is no longer waiting to be unsubscribed.
    self.pending_unsubscribe_shard_channels.difference_update(normalized)
    return reply
1766+
1767+
def sunsubscribe(self, *args, target_node=None):
    """
    Unsubscribe from the given shard channels, or from every shard channel
    when no names are supplied.

    The affected channels are only marked as pending here; this method does
    not remove them from ``shard_channels``.

    ``target_node`` is accepted but not used by this implementation.

    Returns whatever ``execute_command`` returns for ``SUNSUBSCRIBE``.
    """
    if not args:
        # No names given: everything currently subscribed becomes pending.
        to_drop = self.shard_channels
    else:
        args = list_or_args(args[0], args[1:])
        to_drop = self._normalize_keys(dict.fromkeys(args))
    self.pending_unsubscribe_shard_channels.update(to_drop)
    return self.execute_command("SUNSUBSCRIBE", *args)
1779+
17311780
def listen(self):
17321781
"Listen for messages on channels this client has been subscribed to"
17331782
while self.subscribed:
@@ -1762,6 +1811,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
17621811
return self.handle_message(response, ignore_subscribe_messages)
17631812
return None
17641813

1814+
get_sharded_message = get_message
1815+
17651816
def ping(self, message=None):
17661817
"""
17671818
Ping the Redis server
@@ -1809,12 +1860,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
18091860
if pattern in self.pending_unsubscribe_patterns:
18101861
self.pending_unsubscribe_patterns.remove(pattern)
18111862
self.patterns.pop(pattern, None)
1863+
elif message_type == "sunsubscribe":
1864+
s_channel = response[1]
1865+
if s_channel in self.pending_unsubscribe_shard_channels:
1866+
self.pending_unsubscribe_shard_channels.remove(s_channel)
1867+
self.shard_channels.pop(s_channel, None)
18121868
else:
18131869
channel = response[1]
18141870
if channel in self.pending_unsubscribe_channels:
18151871
self.pending_unsubscribe_channels.remove(channel)
18161872
self.channels.pop(channel, None)
1817-
if not self.channels and not self.patterns:
1873+
if not self.channels and not self.patterns and not self.shard_channels:
18181874
# There are no subscriptions anymore, set subscribed_event flag
18191875
# to false
18201876
self.subscribed_event.clear()
@@ -1823,6 +1879,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
18231879
# if there's a message handler, invoke it
18241880
if message_type == "pmessage":
18251881
handler = self.patterns.get(message["pattern"], None)
1882+
elif message_type == "smessage":
1883+
handler = self.shard_channels.get(message["channel"], None)
18261884
else:
18271885
handler = self.channels.get(message["channel"], None)
18281886
if handler:
@@ -1843,6 +1901,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
18431901
for pattern, handler in self.patterns.items():
18441902
if handler is None:
18451903
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
1904+
for s_channel, handler in self.shard_channels.items():
1905+
if handler is None:
1906+
raise PubSubError(
1907+
f"Shard Channel: '{s_channel}' has no handler registered"
1908+
)
18461909

18471910
thread = PubSubWorkerThread(
18481911
self, sleep_time, daemon=daemon, exception_handler=exception_handler

redis/cluster.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from redis.backoff import default_backoff
1010
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
1111
from redis.commands import READ_COMMANDS, RedisClusterCommands
12+
from redis.commands.helpers import list_or_args
1213
from redis.connection import ConnectionPool, DefaultParser, parse_url
1314
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
1415
from redis.exceptions import (
@@ -222,6 +223,8 @@ class AbstractRedisCluster:
222223
"PUBSUB CHANNELS",
223224
"PUBSUB NUMPAT",
224225
"PUBSUB NUMSUB",
226+
"PUBSUB SHARDCHANNELS",
227+
"PUBSUB SHARDNUMSUB",
225228
"PING",
226229
"INFO",
227230
"SHUTDOWN",
@@ -346,11 +349,13 @@ class AbstractRedisCluster:
346349
}
347350

348351
RESULT_CALLBACKS = dict_merge(
349-
list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub),
352+
list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub),
350353
list_keys_to_dict(
351354
["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values()))
352355
),
353-
list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result),
356+
list_keys_to_dict(
357+
["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result
358+
),
354359
list_keys_to_dict(
355360
[
356361
"PING",
@@ -1625,6 +1630,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
16251630
else redis_cluster.get_redis_connection(self.node).connection_pool
16261631
)
16271632
self.cluster = redis_cluster
1633+
self.node_pubsub_mapping = {}
1634+
self._pubsubs_generator = self._pubsubs_generator()
16281635
super().__init__(
16291636
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
16301637
)
@@ -1678,9 +1685,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port):
16781685
f"Node {host}:{port} doesn't exist in the cluster"
16791686
)
16801687

1681-
def execute_command(self, *args, **kwargs):
1688+
def execute_command(self, *args):
16821689
"""
1683-
Execute a publish/subscribe command.
1690+
Execute a subscribe/unsubscribe command.
16841691
16851692
Taken code from redis-py and tweak to make it work within a cluster.
16861693
"""
@@ -1713,13 +1720,103 @@ def execute_command(self, *args, **kwargs):
17131720
connection = self.connection
17141721
self._execute(connection, connection.send_command, *args)
17151722

1723+
def _get_node_pubsub(self, node):
1724+
try:
1725+
return self.node_pubsub_mapping[node.name]
1726+
except KeyError:
1727+
pubsub = node.redis_connection.pubsub()
1728+
self.node_pubsub_mapping[node.name] = pubsub
1729+
return pubsub
1730+
1731+
def _sharded_message_generator(self):
1732+
for _ in range(len(self.node_pubsub_mapping)):
1733+
pubsub = next(self._pubsubs_generator)
1734+
message = pubsub.get_message()
1735+
if message is not None:
1736+
return message
1737+
return None
1738+
1739+
def _pubsubs_generator(self):
1740+
while True:
1741+
for pubsub in self.node_pubsub_mapping.values():
1742+
yield pubsub
1743+
1744+
def get_sharded_message(
    self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
):
    """
    Fetch the next message from the per-node shard pubsubs.

    With ``target_node`` set, only that node's pubsub is polled and
    ``ignore_subscribe_messages`` / ``timeout`` are forwarded to it.
    Otherwise the node pubsubs are polled in turn through
    ``_sharded_message_generator``, which takes no timeout.

    A ``sunsubscribe`` confirmation for a channel that was pending
    unsubscription removes the channel from the local bookkeeping and drops
    the owning node's pubsub once it reports no remaining subscriptions.

    Returns the message dict, or ``None`` when nothing was available.
    NOTE(review): as written, ``None`` is also returned for *any* message
    when ``ignore_subscribe_messages`` is set on the instance or passed in -
    confirm this is intended.
    """
    if target_node:
        node_pubsub = self.node_pubsub_mapping[target_node.name]
        message = node_pubsub.get_message(
            ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
        )
    else:
        message = self._sharded_message_generator()

    if message is None:
        return None

    if str_if_bytes(message["type"]) == "sunsubscribe":
        channel = message["channel"]
        if channel in self.pending_unsubscribe_shard_channels:
            self.pending_unsubscribe_shard_channels.remove(channel)
            self.shard_channels.pop(channel, None)
            owner = self.cluster.get_node_from_key(channel)
            # Forget the node's pubsub once nothing is subscribed on it.
            if self.node_pubsub_mapping[owner.name].subscribed is False:
                self.node_pubsub_mapping.pop(owner.name)

    if not (self.channels or self.patterns or self.shard_channels):
        # There are no subscriptions anymore, set subscribed_event flag
        # to false
        self.subscribed_event.clear()

    if self.ignore_subscribe_messages or ignore_subscribe_messages:
        return None
    return message
1769+
1770+
def ssubscribe(self, *args, **kwargs):
    """
    Subscribe to shard channels, sending each subscription to the node that
    ``self.cluster.get_node_from_key`` returns for the channel name.

    Positional arguments are channel names (or a single list of names)
    without handlers; keyword arguments map a channel name to a handler
    callable. Each channel is subscribed on that node's own pubsub object,
    and the result is mirrored into this object's bookkeeping.
    """
    names = list_or_args(args[0], args[1:]) if args else args
    # Keyword entries override positional ones that share a name.
    wanted = {**dict.fromkeys(names), **kwargs}
    for name, callback in wanted.items():
        owner = self.cluster.get_node_from_key(name)
        node_pubsub = self._get_node_pubsub(owner)
        if not callback:
            node_pubsub.ssubscribe(name)
        else:
            node_pubsub.ssubscribe(**{name: callback})
        # Mirror the node-level state onto this cluster-level object.
        self.shard_channels.update(node_pubsub.shard_channels)
        self.pending_unsubscribe_shard_channels.difference_update(
            self._normalize_keys({name: None})
        )
        if node_pubsub.subscribed and not self.subscribed:
            self.subscribed_event.set()
            self.health_check_response_counter = 0
1789+
1790+
def sunsubscribe(self, *args):
    """
    Unsubscribe from the given shard channels, or from every tracked shard
    channel when no names are supplied.

    Each channel is unsubscribed on the pubsub of the node returned by
    ``self.cluster.get_node_from_key``; that pubsub's pending-unsubscribe
    set is then merged into this object's own.
    """
    targets = list_or_args(args[0], args[1:]) if args else self.shard_channels
    for name in targets:
        node_pubsub = self._get_node_pubsub(self.cluster.get_node_from_key(name))
        node_pubsub.sunsubscribe(name)
        self.pending_unsubscribe_shard_channels.update(
            node_pubsub.pending_unsubscribe_shard_channels
        )
1803+
17161804
def get_redis_connection(self):
    """
    Return ``redis_connection`` of the node this pubsub is bound to, or
    ``None`` when no node is set.
    """
    node = self.node
    if node is None:
        return None
    return node.redis_connection
17221810

1811+
def disconnect(self):
    """
    Disconnect this pubsub's own connection and the connection of every
    per-node pubsub in ``node_pubsub_mapping``.

    Pubsubs without a connection are skipped: a pubsub's ``connection`` is
    ``None`` after a reset, and one such entry must not prevent the
    remaining node connections from being closed.
    """
    if self.connection:
        self.connection.disconnect()
    for pubsub in self.node_pubsub_mapping.values():
        if pubsub.connection:
            pubsub.connection.disconnect()
1819+
17231820

17241821
class ClusterPipeline(RedisCluster):
17251822
"""

redis/commands/core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5103,6 +5103,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT
51035103
"""
51045104
return self.execute_command("PUBLISH", channel, message, **kwargs)
51055105

5106+
def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
    """
    Publish ``message`` on the shard channel ``shard_channel``.

    The reply is the number of clients that received the message.

    For more information see https://redis.io/commands/spublish
    """
    command = ("SPUBLISH", shard_channel, message)
    return self.execute_command(*command)
5114+
51065115
def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
51075116
"""
51085117
Return a list of channels that have at least one subscriber
@@ -5111,6 +5120,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
51115120
"""
51125121
return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs)
51135122

5123+
def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
    """
    List the shard channels that currently have at least one subscriber,
    restricted to names matching ``pattern`` (every channel by default).

    Extra keyword arguments are passed through to ``execute_command``.

    For more information see https://redis.io/commands/pubsub-shardchannels
    """
    reply = self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs)
    return reply
5130+
51145131
def pubsub_numpat(self, **kwargs) -> ResponseT:
51155132
"""
51165133
Returns the number of subscriptions to patterns
@@ -5128,6 +5145,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT:
51285145
"""
51295146
return self.execute_command("PUBSUB NUMSUB", *args, **kwargs)
51305147

5148+
def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT:
    """
    Report subscriber counts for the shard channels named in ``*args`` as a
    list of (shard_channel, number of subscribers) tuples.

    Extra keyword arguments are passed through to ``execute_command``.

    For more information see https://redis.io/commands/pubsub-shardnumsub
    """
    shard_channels = args
    return self.execute_command("PUBSUB SHARDNUMSUB", *shard_channels, **kwargs)
5156+
51315157

51325158
AsyncPubSubCommands = PubSubCommands
51335159

redis/parsers/commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,13 @@ def _get_pubsub_keys(self, *args):
155155
# the second argument is a part of the command name, e.g.
156156
# ['PUBSUB', 'NUMSUB', 'foo'].
157157
pubsub_type = args[1].upper()
158-
if pubsub_type in ["CHANNELS", "NUMSUB"]:
158+
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
159159
keys = args[2:]
160160
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
161161
# format example:
162162
# SUBSCRIBE channel [channel ...]
163163
keys = list(args[1:])
164-
elif command == "PUBLISH":
164+
elif command in ["PUBLISH", "SPUBLISH"]:
165165
# format example:
166166
# PUBLISH channel message
167167
keys = [args[1]]

0 commit comments

Comments
 (0)