Skip to content

Commit 15b8797

Browse files
committed
Use retry mechanism in async version of Connection objects
1 parent 4b0543d commit 15b8797

File tree

5 files changed

+102
-5
lines changed

5 files changed

+102
-5
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
* Add retry mechanism to async version of Connection
23
* Compare commands case-insensitively in the asyncio command parser
34
* Allow negative `retries` for `Retry` class to retry forever
45
* Add `items` parameter to `hset` signature

redis/asyncio/connection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,8 @@ def __init__(
637637
retry_on_error = []
638638
if retry_on_timeout:
639639
retry_on_error.append(TimeoutError)
640+
retry_on_error.append(socket.timeout)
641+
retry_on_error.append(asyncio.TimeoutError)
640642
self.retry_on_error = retry_on_error
641643
if retry_on_error:
642644
if not retry:
@@ -706,7 +708,9 @@ async def connect(self):
706708
if self.is_connected:
707709
return
708710
try:
709-
await self._connect()
711+
await self.retry.call_with_retry(
712+
lambda: self._connect(), lambda error: self.disconnect(error)
713+
)
710714
except asyncio.CancelledError:
711715
raise
712716
except (socket.timeout, asyncio.TimeoutError):
@@ -816,7 +820,7 @@ async def on_connect(self) -> None:
816820
if str_if_bytes(await self.read_response()) != "OK":
817821
raise ConnectionError("Invalid Database")
818822

819-
async def disconnect(self) -> None:
823+
async def disconnect(self, *args) -> None:
820824
"""Disconnects from the Redis server"""
821825
try:
822826
async with async_timeout.timeout(self.socket_connect_timeout):

redis/asyncio/sentinel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def connect_to(self, address):
4444
if str_if_bytes(await self.read_response()) != "PONG":
4545
raise ConnectionError("PING failed")
4646

47-
async def connect(self):
47+
async def _connect_retry(self):
4848
if self._reader:
4949
return # already connected
5050
if self.connection_pool.is_master:
@@ -57,6 +57,12 @@ async def connect(self):
5757
continue
5858
raise SlaveNotFoundError # Never be here
5959

60+
async def connect(self):
61+
return await self.retry.call_with_retry(
62+
self._connect_retry,
63+
lambda error: asyncio.sleep(0),
64+
)
65+
6066
async def read_response(self, disable_decoding: bool = False):
6167
try:
6268
return await super().read_response(disable_decoding=disable_decoding)

tests/test_asyncio/test_connection.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
import asyncio
2+
import socket
23
import types
4+
from unittest.mock import patch
35

46
import pytest
57

6-
from redis.asyncio.connection import PythonParser, UnixDomainSocketConnection
7-
from redis.exceptions import InvalidResponse
8+
from redis.asyncio.connection import (
9+
Connection,
10+
PythonParser,
11+
UnixDomainSocketConnection,
12+
)
13+
from redis.asyncio.retry import Retry
14+
from redis.backoff import NoBackoff
15+
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
816
from redis.utils import HIREDIS_AVAILABLE
917
from tests.conftest import skip_if_server_version_lt
1018

@@ -60,3 +68,44 @@ async def test_socket_param_regression(r):
6068
async def test_can_run_concurrent_commands(r):
6169
assert await r.ping() is True
6270
assert all(await asyncio.gather(*(r.ping() for _ in range(10))))
71+
72+
73+
async def test_connect_retry_on_timeout_error():
74+
"""Test that the _connect function is retried in case of a timeout"""
75+
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 3))
76+
origin_connect = conn._connect
77+
conn._connect = mock.AsyncMock()
78+
79+
async def mock_connect():
80+
# connect only on the last retry
81+
if conn._connect.call_count <= 2:
82+
raise socket.timeout
83+
else:
84+
return await origin_connect()
85+
86+
conn._connect.side_effect = mock_connect
87+
await conn.connect()
88+
assert conn._connect.call_count == 3
89+
90+
91+
async def test_connect_without_retry_on_os_error():
92+
"""Test that the _connect function is not being retried in case of a OSError"""
93+
with patch.object(Connection, "_connect") as _connect:
94+
_connect.side_effect = OSError("")
95+
conn = Connection(retry_on_timeout=True, retry=Retry(NoBackoff(), 2))
96+
with pytest.raises(ConnectionError):
97+
await conn.connect()
98+
assert _connect.call_count == 1
99+
100+
101+
async def test_connect_timeout_error_without_retry():
102+
"""Test that the _connect function is not being retried if retry_on_timeout is
103+
set to False"""
104+
conn = Connection(retry_on_timeout=False)
105+
conn._connect = mock.AsyncMock()
106+
conn._connect.side_effect = socket.timeout
107+
108+
with pytest.raises(TimeoutError) as e:
109+
await conn.connect()
110+
assert conn._connect.call_count == 1
111+
assert str(e.value) == "Timeout connecting to server"
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import socket
2+
3+
import pytest
4+
5+
from redis.asyncio.retry import Retry
6+
from redis.asyncio.sentinel import SentinelManagedConnection
7+
from redis.backoff import NoBackoff
8+
9+
from .compat import mock
10+
11+
pytestmark = pytest.mark.asyncio
12+
13+
14+
async def test_connect_retry_on_timeout_error():
15+
"""Test that the _connect function is retried in case of a timeout"""
16+
connection_pool = mock.AsyncMock()
17+
connection_pool.get_master_address = mock.AsyncMock(
18+
return_value=("localhost", 6379)
19+
)
20+
conn = SentinelManagedConnection(
21+
retry_on_timeout=True,
22+
retry=Retry(NoBackoff(), 3),
23+
connection_pool=connection_pool,
24+
)
25+
origin_connect = conn._connect
26+
conn._connect = mock.AsyncMock()
27+
28+
async def mock_connect():
29+
# connect only on the last retry
30+
if conn._connect.call_count <= 2:
31+
raise socket.timeout
32+
else:
33+
return await origin_connect()
34+
35+
conn._connect.side_effect = mock_connect
36+
await conn.connect()
37+
assert conn._connect.call_count == 3

0 commit comments

Comments
 (0)