Skip to content

Commit f9f9d06

Browse files
automatically reconnect pubsub when reading messages in blocking mode (#2281)
* optimistic default info on test sessionstart. Makes test discovery work, even without a redis connection. * Add unittests verifying that (non-async) PubSub will automatically reconnect * Add tests for asyncio pubsub subsciription auto-reconnect * automatically connect for blocking reads (asyncio) * fix automatic connect on blocking pubsub read (non-async) * lint & format * Perform `connect()` call in PubSub code rather than `read_response`.
1 parent 48f5aca commit f9f9d06

File tree

6 files changed

+298
-12
lines changed

6 files changed

+298
-12
lines changed

redis/asyncio/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -754,9 +754,15 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
754754

755755
await self.check_health()
756756

757-
if not block and not await self._execute(conn, conn.can_read, timeout=timeout):
758-
return None
759-
response = await self._execute(conn, conn.read_response)
757+
async def try_read():
758+
if not block:
759+
if not await conn.can_read(timeout=timeout):
760+
return None
761+
else:
762+
await conn.connect()
763+
return await conn.read_response()
764+
765+
response = await self._execute(conn, try_read)
760766

761767
if conn.health_check_interval and response == self.health_check_response:
762768
# ignore the health check message as user might not expect it

redis/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,9 +1497,15 @@ def parse_response(self, block=True, timeout=0):
14971497

14981498
self.check_health()
14991499

1500-
if not block and not self._execute(conn, conn.can_read, timeout=timeout):
1501-
return None
1502-
response = self._execute(conn, conn.read_response)
1500+
def try_read():
1501+
if not block:
1502+
if not conn.can_read(timeout=timeout):
1503+
return None
1504+
else:
1505+
conn.connect()
1506+
return conn.read_response()
1507+
1508+
response = self._execute(conn, try_read)
15031509

15041510
if self.is_health_check_response(response):
15051511
# ignore the health check message as user might not expect it

tests/conftest.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,25 @@ def _get_info(redis_url):
130130

131131

132132
def pytest_sessionstart(session):
133+
# during test discovery, e.g. with VS Code, we may not
134+
# have a server running.
133135
redis_url = session.config.getoption("--redis-url")
134-
info = _get_info(redis_url)
135-
version = info["redis_version"]
136-
arch_bits = info["arch_bits"]
137-
cluster_enabled = info["cluster_enabled"]
136+
try:
137+
info = _get_info(redis_url)
138+
version = info["redis_version"]
139+
arch_bits = info["arch_bits"]
140+
cluster_enabled = info["cluster_enabled"]
141+
enterprise = info["enterprise"]
142+
except redis.ConnectionError:
143+
# provide optimistic defaults
144+
version = "10.0.0"
145+
arch_bits = 64
146+
cluster_enabled = False
147+
enterprise = False
138148
REDIS_INFO["version"] = version
139149
REDIS_INFO["arch_bits"] = arch_bits
140150
REDIS_INFO["cluster_enabled"] = cluster_enabled
141-
REDIS_INFO["enterprise"] = info["enterprise"]
151+
REDIS_INFO["enterprise"] = enterprise
142152
# store REDIS_INFO in config so that it is available from "condition strings"
143153
session.config.REDIS_INFO = REDIS_INFO
144154

tests/test_asyncio/compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
import asyncio
2+
import sys
13
from unittest import mock
24

35
try:
46
mock.AsyncMock
57
except AttributeError:
68
import mock
9+
10+
11+
def create_task(coroutine):
12+
if sys.version_info[:2] >= (3, 7):
13+
return asyncio.create_task(coroutine)
14+
else:
15+
return asyncio.ensure_future(coroutine)

tests/test_asyncio/test_pubsub.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import functools
3+
import socket
34
from typing import Optional
45

56
import async_timeout
@@ -11,7 +12,7 @@
1112
from redis.typing import EncodableT
1213
from tests.conftest import skip_if_server_version_lt
1314

14-
from .compat import mock
15+
from .compat import create_task, mock
1516

1617

1718
def with_timeout(t):
@@ -786,3 +787,130 @@ def callback(message):
786787
"pattern": None,
787788
"type": "message",
788789
}
790+
791+
792+
# @pytest.mark.xfail
793+
@pytest.mark.parametrize("method", ["get_message", "listen"])
794+
@pytest.mark.onlynoncluster
795+
class TestPubSubAutoReconnect:
796+
timeout = 2
797+
798+
async def mysetup(self, r, method):
799+
self.messages = asyncio.Queue()
800+
self.pubsub = r.pubsub()
801+
# State: 0 = initial state , 1 = after disconnect, 2 = ConnectionError is seen,
802+
# 3=successfully reconnected 4 = exit
803+
self.state = 0
804+
self.cond = asyncio.Condition()
805+
if method == "get_message":
806+
self.get_message = self.loop_step_get_message
807+
else:
808+
self.get_message = self.loop_step_listen
809+
810+
self.task = create_task(self.loop())
811+
# get the initial connect message
812+
message = await self.messages.get()
813+
assert message == {
814+
"channel": b"foo",
815+
"data": 1,
816+
"pattern": None,
817+
"type": "subscribe",
818+
}
819+
820+
async def mycleanup(self):
821+
message = await self.messages.get()
822+
assert message == {
823+
"channel": b"foo",
824+
"data": 1,
825+
"pattern": None,
826+
"type": "subscribe",
827+
}
828+
# kill thread
829+
async with self.cond:
830+
self.state = 4 # quit
831+
await self.task
832+
833+
async def test_reconnect_socket_error(self, r: redis.Redis, method):
834+
"""
835+
Test that a socket error will cause reconnect
836+
"""
837+
async with async_timeout.timeout(self.timeout):
838+
await self.mysetup(r, method)
839+
# now, disconnect the connection, and wait for it to be re-established
840+
async with self.cond:
841+
assert self.state == 0
842+
self.state = 1
843+
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
844+
mockobj.read_response.side_effect = socket.error
845+
mockobj.can_read.side_effect = socket.error
846+
# wait until task noticies the disconnect until we undo the patch
847+
await self.cond.wait_for(lambda: self.state >= 2)
848+
assert not self.pubsub.connection.is_connected
849+
# it is in a disconnecte state
850+
# wait for reconnect
851+
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
852+
assert self.state == 3
853+
854+
await self.mycleanup()
855+
856+
async def test_reconnect_disconnect(self, r: redis.Redis, method):
857+
"""
858+
Test that a manual disconnect() will cause reconnect
859+
"""
860+
async with async_timeout.timeout(self.timeout):
861+
await self.mysetup(r, method)
862+
# now, disconnect the connection, and wait for it to be re-established
863+
async with self.cond:
864+
self.state = 1
865+
await self.pubsub.connection.disconnect()
866+
assert not self.pubsub.connection.is_connected
867+
# wait for reconnect
868+
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
869+
assert self.state == 3
870+
871+
await self.mycleanup()
872+
873+
async def loop(self):
874+
# reader loop, performing state transitions as it
875+
# discovers disconnects and reconnects
876+
await self.pubsub.subscribe("foo")
877+
while True:
878+
await asyncio.sleep(0.01) # give main thread chance to get lock
879+
async with self.cond:
880+
old_state = self.state
881+
try:
882+
if self.state == 4:
883+
break
884+
# print("state a ", self.state)
885+
got_msg = await self.get_message()
886+
assert got_msg
887+
if self.state in (1, 2):
888+
self.state = 3 # successful reconnect
889+
except redis.ConnectionError:
890+
assert self.state in (1, 2)
891+
self.state = 2 # signal that we noticed the disconnect
892+
finally:
893+
self.cond.notify()
894+
# make sure that we did notice the connection error
895+
# or reconnected without any error
896+
if old_state == 1:
897+
assert self.state in (2, 3)
898+
899+
async def loop_step_get_message(self):
900+
# get a single message via get_message
901+
message = await self.pubsub.get_message(timeout=0.1)
902+
# print(message)
903+
if message is not None:
904+
await self.messages.put(message)
905+
return True
906+
return False
907+
908+
async def loop_step_listen(self):
909+
# get a single message via listen()
910+
try:
911+
async with async_timeout.timeout(0.1):
912+
async for message in self.pubsub.listen():
913+
await self.messages.put(message)
914+
return True
915+
except asyncio.TimeoutError:
916+
return False

tests/test_pubsub.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import platform
2+
import queue
3+
import socket
24
import threading
35
import time
46
from unittest import mock
@@ -608,3 +610,128 @@ def test_pubsub_deadlock(self, master_host):
608610
p = r.pubsub()
609611
p.subscribe("my-channel-1", "my-channel-2")
610612
pool.reset()
613+
614+
615+
@pytest.mark.timeout(5, method="thread")
616+
@pytest.mark.parametrize("method", ["get_message", "listen"])
617+
@pytest.mark.onlynoncluster
618+
class TestPubSubAutoReconnect:
619+
def mysetup(self, r, method):
620+
self.messages = queue.Queue()
621+
self.pubsub = r.pubsub()
622+
self.state = 0
623+
self.cond = threading.Condition()
624+
if method == "get_message":
625+
self.get_message = self.loop_step_get_message
626+
else:
627+
self.get_message = self.loop_step_listen
628+
629+
self.thread = threading.Thread(target=self.loop)
630+
self.thread.daemon = True
631+
self.thread.start()
632+
# get the initial connect message
633+
message = self.messages.get(timeout=1)
634+
assert message == {
635+
"channel": b"foo",
636+
"data": 1,
637+
"pattern": None,
638+
"type": "subscribe",
639+
}
640+
641+
def wait_for_reconnect(self):
642+
self.cond.wait_for(lambda: self.pubsub.connection._sock is not None, timeout=2)
643+
assert self.pubsub.connection._sock is not None # we didn't time out
644+
assert self.state == 3
645+
646+
message = self.messages.get(timeout=1)
647+
assert message == {
648+
"channel": b"foo",
649+
"data": 1,
650+
"pattern": None,
651+
"type": "subscribe",
652+
}
653+
654+
def mycleanup(self):
655+
# kill thread
656+
with self.cond:
657+
self.state = 4 # quit
658+
self.cond.notify()
659+
self.thread.join()
660+
661+
def test_reconnect_socket_error(self, r: redis.Redis, method):
662+
"""
663+
Test that a socket error will cause reconnect
664+
"""
665+
self.mysetup(r, method)
666+
try:
667+
# now, disconnect the connection, and wait for it to be re-established
668+
with self.cond:
669+
self.state = 1
670+
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
671+
mockobj.read_response.side_effect = socket.error
672+
mockobj.can_read.side_effect = socket.error
673+
# wait until thread notices the disconnect until we undo the patch
674+
self.cond.wait_for(lambda: self.state >= 2)
675+
assert (
676+
self.pubsub.connection._sock is None
677+
) # it is in a disconnected state
678+
self.wait_for_reconnect()
679+
680+
finally:
681+
self.mycleanup()
682+
683+
def test_reconnect_disconnect(self, r: redis.Redis, method):
684+
"""
685+
Test that a manual disconnect() will cause reconnect
686+
"""
687+
self.mysetup(r, method)
688+
try:
689+
# now, disconnect the connection, and wait for it to be re-established
690+
with self.cond:
691+
self.state = 1
692+
self.pubsub.connection.disconnect()
693+
assert self.pubsub.connection._sock is None
694+
# wait for reconnect
695+
self.wait_for_reconnect()
696+
finally:
697+
self.mycleanup()
698+
699+
def loop(self):
700+
# reader loop, performing state transitions as it
701+
# discovers disconnects and reconnects
702+
self.pubsub.subscribe("foo")
703+
while True:
704+
time.sleep(0.01) # give main thread chance to get lock
705+
with self.cond:
706+
old_state = self.state
707+
try:
708+
if self.state == 4:
709+
break
710+
# print ('state, %s, sock %s' % (state, pubsub.connection._sock))
711+
got_msg = self.get_message()
712+
assert got_msg
713+
if self.state in (1, 2):
714+
self.state = 3 # successful reconnect
715+
except redis.ConnectionError:
716+
assert self.state in (1, 2)
717+
self.state = 2
718+
finally:
719+
self.cond.notify()
720+
# assert that we noticed a connect error, or automatically
721+
# reconnected without error
722+
if old_state == 1:
723+
assert self.state in (2, 3)
724+
725+
def loop_step_get_message(self):
726+
# get a single message via listen()
727+
message = self.pubsub.get_message(timeout=0.1)
728+
if message is not None:
729+
self.messages.put(message)
730+
return True
731+
return False
732+
733+
def loop_step_listen(self):
734+
# get a single message via listen()
735+
for message in self.pubsub.listen():
736+
self.messages.put(message)
737+
return True

0 commit comments

Comments
 (0)