Skip to content

Commit 84e00ee

Browse files
committed
Use retry mechanism in async version of Connection objects
1 parent 4b0543d commit 84e00ee

File tree

4 files changed

+92
-5
lines changed

4 files changed

+92
-5
lines changed

redis/asyncio/connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,8 @@ def __init__(
637637
retry_on_error = []
638638
if retry_on_timeout:
639639
retry_on_error.append(TimeoutError)
640+
retry_on_error.append(socket.timeout)
641+
retry_on_error.append(asyncio.TimeoutError)
640642
self.retry_on_error = retry_on_error
641643
if retry_on_error:
642644
if not retry:
@@ -706,7 +708,9 @@ async def connect(self):
706708
if self.is_connected:
707709
return
708710
try:
709-
await self._connect()
711+
await self.retry.call_with_retry(
712+
lambda: self._connect(), lambda error: self.disconnect(error)
713+
)
710714
except asyncio.CancelledError:
711715
raise
712716
except (socket.timeout, asyncio.TimeoutError):
@@ -816,7 +820,7 @@ async def on_connect(self) -> None:
816820
if str_if_bytes(await self.read_response()) != "OK":
817821
raise ConnectionError("Invalid Database")
818822

819-
async def disconnect(self) -> None:
823+
async def disconnect(self, *args) -> None:
820824
"""Disconnects from the Redis server"""
821825
try:
822826
async with async_timeout.timeout(self.socket_connect_timeout):

redis/asyncio/sentinel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def connect_to(self, address):
4444
if str_if_bytes(await self.read_response()) != "PONG":
4545
raise ConnectionError("PING failed")
4646

47-
async def connect(self):
47+
async def _connect_retry(self):
4848
if self._reader:
4949
return # already connected
5050
if self.connection_pool.is_master:
@@ -57,6 +57,12 @@ async def connect(self):
5757
continue
5858
raise SlaveNotFoundError # Never be here
5959

60+
async def connect(self):
61+
return await self.retry.call_with_retry(
62+
self._connect_retry,
63+
lambda error: asyncio.sleep(0),
64+
)
65+
6066
async def read_response(self, disable_decoding: bool = False):
6167
try:
6268
return await super().read_response(disable_decoding=disable_decoding)

tests/test_asyncio/test_connection.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import asyncio
2+
import socket
23
import types
4+
from unittest.mock import patch
35

46
import pytest
57

6-
from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection
7-
from redis.exceptions import InvalidResponse
8+
from redis.asyncio.connection import Connection, PythonParser, UnixDomainSocketConnection
9+
from redis.asyncio.retry import Retry
10+
from redis.backoff import NoBackoff
11+
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
12+
813
from redis.utils import HIREDIS_AVAILABLE
914
from tests.conftest import skip_if_server_version_lt
1015

@@ -60,3 +65,44 @@ async def test_socket_param_regression(r):
6065
async def test_can_run_concurrent_commands(r):
6166
assert await r.ping() is True
6267
assert all(await asyncio.gather(*(r.ping() for _ in range(10))))
68+
69+
70+
async def test_connect_retry_on_timeout_error():
71+
"""Test that the _connect function is retried in case of a timeout"""
72+
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3))
73+
origin_connect = conn._connect
74+
conn._connect = mock.AsyncMock()
75+
76+
async def mock_connect():
77+
# connect only on the last retry
78+
if conn._connect.call_count <= 2:
79+
raise socket.timeout
80+
else:
81+
return await origin_connect()
82+
83+
conn._connect.side_effect = mock_connect
84+
await conn.connect()
85+
assert conn._connect.call_count == 3
86+
87+
88+
async def test_connect_without_retry_on_os_error():
89+
"""Test that the _connect function is not being retried in case of a OSError"""
90+
with patch.object(Connection, "_connect") as _connect:
91+
_connect.side_effect = OSError("")
92+
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
93+
with pytest.raises(ConnectionError):
94+
await conn.connect()
95+
assert _connect.call_count == 1
96+
97+
98+
async def test_connect_timeout_error_without_retry():
99+
"""Test that the _connect function is not being retried if retry_on_timeout is
100+
set to False"""
101+
conn = Connection(retry_on_timeout=False)
102+
conn._connect = mock.AsyncMock()
103+
conn._connect.side_effect = socket.timeout
104+
105+
with pytest.raises(TimeoutError) as e:
106+
await conn.connect()
107+
assert conn._connect.call_count == 1
108+
assert str(e.value) == "Timeout connecting to server"
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import socket
2+
3+
import pytest
4+
5+
from redis.asyncio.retry import Retry
6+
from redis.asyncio.sentinel import SentinelManagedConnection
7+
from redis.backoff import NoBackoff
8+
9+
from .compat import mock
10+
11+
pytestmark = pytest.mark.asyncio
12+
13+
14+
async def test_connect_retry_on_timeout_error():
15+
"""Test that the _connect function is retried in case of a timeout"""
16+
connection_pool = mock.AsyncMock()
17+
connection_pool.get_master_address = mock.AsyncMock(return_value=("localhost", 6379))
18+
conn = SentinelManagedConnection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3), connection_pool=connection_pool)
19+
origin_connect = conn._connect
20+
conn._connect = mock.AsyncMock()
21+
22+
async def mock_connect():
23+
# connect only on the last retry
24+
if conn._connect.call_count <= 2:
25+
raise socket.timeout
26+
else:
27+
return await origin_connect()
28+
29+
conn._connect.side_effect = mock_connect
30+
await conn.connect()
31+
assert conn._connect.call_count == 3

0 commit comments

Comments
 (0)