Skip to content

Commit b65989d

Browse files
kristjanvalurdlunch
authored andcommitted
Optionally disable disconnects in read_response (redis#2695)
* Add regression tests and fixes for issue redis#1128 * Fix tests for resumable read_response to use "disconnect_on_error" * undo prevision fix attempts in async client and cluster * re-enable cluster test * Suggestions from code review * Add CHANGES
1 parent ee51d2d commit b65989d

File tree

10 files changed

+151
-91
lines changed

10 files changed

+151
-91
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Revert #2104, #2673, add `disconnect_on_error` option to `read_response()` (issues #2506, #2624)
12
* Fix string cleanse in Redis Graph
23
* Make PythonParser resumable in case of error (#2510)
34
* Add `timeout=None` in `SentinelConnectionManager.read_response`

redis/asyncio/client.py

Lines changed: 29 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,20 @@ async def execute_command(self, *args, **options):
516516
command_name = args[0]
517517
conn = self.connection or await pool.get_connection(command_name, **options)
518518

519-
return await asyncio.shield(
520-
self._try_send_command_parse_response(conn, *args, **options)
521-
)
519+
if self.single_connection_client:
520+
await self._single_conn_lock.acquire()
521+
try:
522+
return await conn.retry.call_with_retry(
523+
lambda: self._send_command_parse_response(
524+
conn, command_name, *args, **options
525+
),
526+
lambda error: self._disconnect_raise(conn, error),
527+
)
528+
finally:
529+
if self.single_connection_client:
530+
self._single_conn_lock.release()
531+
if not self.connection:
532+
await pool.release(conn)
522533

523534
async def parse_response(
524535
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -757,18 +768,10 @@ async def _disconnect_raise_connect(self, conn, error):
757768
is not a TimeoutError. Otherwise, try to reconnect
758769
"""
759770
await conn.disconnect()
760-
761771
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
762772
raise error
763773
await conn.connect()
764774

765-
async def _try_execute(self, conn, command, *arg, **kwargs):
766-
try:
767-
return await command(*arg, **kwargs)
768-
except asyncio.CancelledError:
769-
await conn.disconnect()
770-
raise
771-
772775
async def _execute(self, conn, command, *args, **kwargs):
773776
"""
774777
Connect manually upon disconnection. If the Redis server is down,
@@ -777,11 +780,9 @@ async def _execute(self, conn, command, *args, **kwargs):
777780
called by the # connection to resubscribe us to any channels and
778781
patterns we were previously listening to
779782
"""
780-
return await asyncio.shield(
781-
conn.retry.call_with_retry(
782-
lambda: self._try_execute(conn, command, *args, **kwargs),
783-
lambda error: self._disconnect_raise_connect(conn, error),
784-
)
783+
return await conn.retry.call_with_retry(
784+
lambda: command(*args, **kwargs),
785+
lambda error: self._disconnect_raise_connect(conn, error),
785786
)
786787

787788
async def parse_response(self, block: bool = True, timeout: float = 0):
@@ -799,7 +800,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0):
799800
await conn.connect()
800801

801802
read_timeout = None if block else timeout
802-
response = await self._execute(conn, conn.read_response, timeout=read_timeout)
803+
response = await self._execute(
804+
conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False
805+
)
803806

804807
if conn.health_check_interval and response == self.health_check_response:
805808
# ignore the health check message as user might not expect it
@@ -1183,18 +1186,6 @@ async def _disconnect_reset_raise(self, conn, error):
11831186
await self.reset()
11841187
raise
11851188

1186-
async def _try_send_command_parse_response(self, conn, *args, **options):
1187-
try:
1188-
return await conn.retry.call_with_retry(
1189-
lambda: self._send_command_parse_response(
1190-
conn, args[0], *args, **options
1191-
),
1192-
lambda error: self._disconnect_reset_raise(conn, error),
1193-
)
1194-
except asyncio.CancelledError:
1195-
await conn.disconnect()
1196-
raise
1197-
11981189
async def immediate_execute_command(self, *args, **options):
11991190
"""
12001191
Execute a command immediately, but don't auto-retry on a
@@ -1210,8 +1201,12 @@ async def immediate_execute_command(self, *args, **options):
12101201
command_name, self.shard_hint
12111202
)
12121203
self.connection = conn
1213-
return await asyncio.shield(
1214-
self._try_send_command_parse_response(conn, *args, **options)
1204+
1205+
return await conn.retry.call_with_retry(
1206+
lambda: self._send_command_parse_response(
1207+
conn, command_name, *args, **options
1208+
),
1209+
lambda error: self._disconnect_reset_raise(conn, error),
12151210
)
12161211

12171212
def pipeline_execute_command(self, *args, **options):
@@ -1379,19 +1374,6 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
13791374
await self.reset()
13801375
raise
13811376

1382-
async def _try_execute(self, conn, execute, stack, raise_on_error):
1383-
try:
1384-
return await conn.retry.call_with_retry(
1385-
lambda: execute(conn, stack, raise_on_error),
1386-
lambda error: self._disconnect_raise_reset(conn, error),
1387-
)
1388-
except asyncio.CancelledError:
1389-
# not supposed to be possible, yet here we are
1390-
await conn.disconnect(nowait=True)
1391-
raise
1392-
finally:
1393-
await self.reset()
1394-
13951377
async def execute(self, raise_on_error: bool = True):
13961378
"""Execute all the commands in the current pipeline"""
13971379
stack = self.command_stack
@@ -1413,11 +1395,10 @@ async def execute(self, raise_on_error: bool = True):
14131395
conn = cast(Connection, conn)
14141396

14151397
try:
1416-
return await asyncio.shield(
1417-
self._try_execute(conn, execute, stack, raise_on_error)
1398+
return await conn.retry.call_with_retry(
1399+
lambda: execute(conn, stack, raise_on_error),
1400+
lambda error: self._disconnect_raise_reset(conn, error),
14181401
)
1419-
except RuntimeError:
1420-
await self.reset()
14211402
finally:
14221403
await self.reset()
14231404

redis/asyncio/cluster.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,33 +1002,12 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
10021002
await connection.send_packed_command(connection.pack_command(*args), False)
10031003

10041004
# Read response
1005-
return await asyncio.shield(
1006-
self._parse_and_release(connection, args[0], **kwargs)
1007-
)
1008-
1009-
async def _parse_and_release(self, connection, *args, **kwargs):
10101005
try:
1011-
return await self.parse_response(connection, *args, **kwargs)
1012-
except asyncio.CancelledError:
1013-
# should not be possible
1014-
await connection.disconnect(nowait=True)
1015-
raise
1006+
return await self.parse_response(connection, args[0], **kwargs)
10161007
finally:
1008+
# Release connection
10171009
self._free.append(connection)
10181010

1019-
async def _try_parse_response(self, cmd, connection, ret):
1020-
try:
1021-
cmd.result = await asyncio.shield(
1022-
self.parse_response(connection, cmd.args[0], **cmd.kwargs)
1023-
)
1024-
except asyncio.CancelledError:
1025-
await connection.disconnect(nowait=True)
1026-
raise
1027-
except Exception as e:
1028-
cmd.result = e
1029-
ret = True
1030-
return ret
1031-
10321011
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10331012
# Acquire connection
10341013
connection = self.acquire_connection()
@@ -1041,7 +1020,13 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
10411020
# Read responses
10421021
ret = False
10431022
for cmd in commands:
1044-
ret = await asyncio.shield(self._try_parse_response(cmd, connection, ret))
1023+
try:
1024+
cmd.result = await self.parse_response(
1025+
connection, cmd.args[0], **cmd.kwargs
1026+
)
1027+
except Exception as e:
1028+
cmd.result = e
1029+
ret = True
10451030

10461031
# Release connection
10471032
self._free.append(connection)

redis/asyncio/connection.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,11 @@ async def send_packed_command(
796796
raise ConnectionError(
797797
f"Error {err_no} while writing to socket. {errmsg}."
798798
) from e
799-
except Exception:
799+
except BaseException:
800+
# BaseExceptions can be raised when a socket send operation is not
801+
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
802+
# to send un-sent data. However, the send_packed_command() API
803+
# does not support it so there is no point in keeping the connection open.
800804
await self.disconnect(nowait=True)
801805
raise
802806

@@ -820,6 +824,8 @@ async def read_response(
820824
self,
821825
disable_decoding: bool = False,
822826
timeout: Optional[float] = None,
827+
*,
828+
disconnect_on_error: bool = True,
823829
):
824830
"""Read the response from a previously sent command"""
825831
read_timeout = timeout if timeout is not None else self.socket_timeout
@@ -835,22 +841,24 @@ async def read_response(
835841
)
836842
except asyncio.TimeoutError:
837843
if timeout is not None:
838-
# user requested timeout, return None
844+
# user requested timeout, return None. Operation can be retried
839845
return None
840846
# it was a self.socket_timeout error.
841-
await self.disconnect(nowait=True)
847+
if disconnect_on_error:
848+
await self.disconnect(nowait=True)
842849
raise TimeoutError(f"Timeout reading from {self.host}:{self.port}")
843850
except OSError as e:
844-
await self.disconnect(nowait=True)
851+
if disconnect_on_error:
852+
await self.disconnect(nowait=True)
845853
raise ConnectionError(
846854
f"Error while reading from {self.host}:{self.port} : {e.args}"
847855
)
848-
except asyncio.CancelledError:
849-
# need this check for 3.7, where CancelledError
850-
# is subclass of Exception, not BaseException
851-
raise
852-
except Exception:
853-
await self.disconnect(nowait=True)
856+
except BaseException:
857+
# Also by default close in case of BaseException. A lot of code
858+
# relies on this behaviour when doing Command/Response pairs.
859+
# See #1128.
860+
if disconnect_on_error:
861+
await self.disconnect(nowait=True)
854862
raise
855863

856864
if self.health_check_interval:

redis/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1539,7 +1539,7 @@ def try_read():
15391539
return None
15401540
else:
15411541
conn.connect()
1542-
return conn.read_response()
1542+
return conn.read_response(disconnect_on_error=False)
15431543

15441544
response = self._execute(conn, try_read)
15451545

redis/connection.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -816,7 +816,11 @@ def send_packed_command(self, command, check_health=True):
816816
errno = e.args[0]
817817
errmsg = e.args[1]
818818
raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.")
819-
except Exception:
819+
except BaseException:
820+
# BaseExceptions can be raised when a socket send operation is not
821+
# finished, e.g. due to a timeout. Ideally, a caller could then re-try
822+
# to send un-sent data. However, the send_packed_command() API
823+
# does not support it so there is no point in keeping the connection open.
820824
self.disconnect()
821825
raise
822826

@@ -840,23 +844,31 @@ def can_read(self, timeout=0):
840844
self.disconnect()
841845
raise ConnectionError(f"Error while reading from {host_error}: {e.args}")
842846

843-
def read_response(self, disable_decoding=False):
847+
def read_response(
848+
self, disable_decoding=False, *, disconnect_on_error: bool = True
849+
):
844850
"""Read the response from a previously sent command"""
845851

846852
host_error = self._host_error()
847853

848854
try:
849855
response = self._parser.read_response(disable_decoding=disable_decoding)
850856
except socket.timeout:
851-
self.disconnect()
857+
if disconnect_on_error:
858+
self.disconnect()
852859
raise TimeoutError(f"Timeout reading from {host_error}")
853860
except OSError as e:
854-
self.disconnect()
861+
if disconnect_on_error:
862+
self.disconnect()
855863
raise ConnectionError(
856864
f"Error while reading from {host_error}" f" : {e.args}"
857865
)
858-
except Exception:
859-
self.disconnect()
866+
except BaseException:
867+
# Also by default close in case of BaseException. A lot of code
868+
# relies on this behaviour when doing Command/Response pairs.
869+
# See #1128.
870+
if disconnect_on_error:
871+
self.disconnect()
860872
raise
861873

862874
if self.health_check_interval:

tests/test_asyncio/test_commands.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""
22
Tests async overrides of commands from their mixins
33
"""
4+
import asyncio
45
import binascii
56
import datetime
67
import re
8+
import sys
79
from string import ascii_letters
810

911
import pytest
@@ -18,6 +20,11 @@
1820
skip_unless_arch_bits,
1921
)
2022

23+
if sys.version_info >= (3, 11, 3):
24+
from asyncio import timeout as async_timeout
25+
else:
26+
from async_timeout import timeout as async_timeout
27+
2128
REDIS_6_VERSION = "5.9.0"
2229

2330

@@ -2999,6 +3006,37 @@ async def test_module_list(self, r: redis.Redis):
29993006
for x in await r.module_list():
30003007
assert isinstance(x, dict)
30013008

3009+
@pytest.mark.onlynoncluster
3010+
async def test_interrupted_command(self, r: redis.Redis):
3011+
"""
3012+
Regression test for issue #1128: An Un-handled BaseException
3013+
will leave the socket with un-read response to a previous
3014+
command.
3015+
"""
3016+
ready = asyncio.Event()
3017+
3018+
async def helper():
3019+
with pytest.raises(asyncio.CancelledError):
3020+
# blocking pop
3021+
ready.set()
3022+
await r.brpop(["nonexist"])
3023+
# If the following is not done, further Timout operations will fail,
3024+
# because the timeout won't catch its Cancelled Error if the task
3025+
# has a pending cancel. Python documentation probably should reflect this.
3026+
if sys.version_info >= (3, 11):
3027+
asyncio.current_task().uncancel()
3028+
# if all is well, we can continue. The following should not hang.
3029+
await r.set("status", "down")
3030+
3031+
task = asyncio.create_task(helper())
3032+
await ready.wait()
3033+
await asyncio.sleep(0.01)
3034+
# the task is now sleeping, lets send it an exception
3035+
task.cancel()
3036+
# If all is well, the task should finish right away, otherwise fail with Timeout
3037+
async with async_timeout(0.1):
3038+
await task
3039+
30023040

30033041
@pytest.mark.onlynoncluster
30043042
class TestBinarySave:

tests/test_asyncio/test_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def test_connection_parse_response_resume(r: redis.Redis):
137137
conn._parser._stream = MockStream(message, interrupt_every=2)
138138
for i in range(100):
139139
try:
140-
response = await conn.read_response()
140+
response = await conn.read_response(disconnect_on_error=False)
141141
break
142142
except MockStream.TestError:
143143
pass

0 commit comments

Comments
 (0)