diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index b3dae2a27b..ec1ce5a915 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -27,6 +27,8 @@ ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse +from ..utils import format_error_message + # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 if sys.version_info >= (3, 11, 3): @@ -345,9 +347,8 @@ async def _connect(self): def _host_error(self) -> str: pass - @abstractmethod def _error_message(self, exception: BaseException) -> str: - pass + return format_error_message(self._host_error(), exception) async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" @@ -799,27 +800,6 @@ async def _connect(self): def _host_error(self) -> str: return f"{self.host}:{self.port}" - def _error_message(self, exception: BaseException) -> str: - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if not exception.args: - # asyncio has a bug where on Connection reset by peer, the - # exception is not instanciated, so args is empty. This is the - # workaround. - # See: https://github.com/redis/redis-py/issues/2237 - # See: https://github.com/python/cpython/issues/94061 - return f"Error connecting to {host_error}. Connection reset by peer" - elif len(exception.args) == 1: - return f"Error connecting to {host_error}. {exception.args[0]}." - else: - return ( - f"Error {exception.args[0]} connecting to {host_error}. " - f"{exception}." - ) - class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -971,20 +951,6 @@ async def _connect(self): def _host_error(self) -> str: return self.path - def _error_message(self, exception: BaseException) -> str: - # args for socket.error can either be (errno, "message") - # or just "message" - host_error = self._host_error() - if len(exception.args) == 1: - return ( - f"Error connecting to unix socket: {host_error}. {exception.args[0]}." - ) - else: - return ( - f"Error {exception.args[0]} connecting to unix socket: " - f"{host_error}. {exception.args[1]}." - ) - FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") diff --git a/redis/connection.py b/redis/connection.py index 728c221257..6e3b3ab081 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -39,6 +39,7 @@ HIREDIS_AVAILABLE, HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, + format_error_message, get_lib_version, str_if_bytes, ) @@ -338,9 +339,8 @@ def _connect(self): def _host_error(self): pass - @abstractmethod def _error_message(self, exception): - pass + return format_error_message(self._host_error(), exception) def on_connect(self): "Initialize the connection, authenticate and select a database" @@ -733,27 +733,6 @@ def _connect(self): def _host_error(self): return f"{self.host}:{self.port}" - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if len(exception.args) == 1: - try: - return f"Error connecting to {host_error}. \ - {exception.args[0]}." - except AttributeError: - return f"Connection Error: {exception.args[0]}" - else: - try: - return ( - f"Error {exception.args[0]} connecting to " - f"{host_error}. {exception.args[1]}." - ) - except AttributeError: - return f"Connection Error: {exception.args[0]}" - class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). @@ -930,20 +909,6 @@ def _connect(self): def _host_error(self): return self.path - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - host_error = self._host_error() - if len(exception.args) == 1: - return ( - f"Error connecting to unix socket: {host_error}. {exception.args[0]}." - ) - else: - return ( - f"Error {exception.args[0]} connecting to unix socket: " - f"{host_error}. {exception.args[1]}." - ) - FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") diff --git a/redis/utils.py b/redis/utils.py index ea2eac149e..360ee54b8c 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -141,3 +141,15 @@ def get_lib_version(): except metadata.PackageNotFoundError: libver = "99.99.99" return libver + + +def format_error_message(host_error: str, exception: BaseException) -> str: + if not exception.args: + return f"Error connecting to {host_error}." + elif len(exception.args) == 1: + return f"Error {exception.args[0]} connecting to {host_error}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[1]}." + ) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 6255ae7d6d..8f79f7d947 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -12,7 +12,12 @@ _AsyncRESPBase, ) from redis.asyncio import ConnectionPool, Redis -from redis.asyncio.connection import Connection, UnixDomainSocketConnection, parse_url +from redis.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, + parse_url, +) from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError @@ -494,18 +499,50 @@ async def test_connection_garbage_collection(request): @pytest.mark.parametrize( - "error, expected_message", + "conn, error, expected_message", [ - (OSError(), "Error connecting to localhost:6379. Connection reset by peer"), - (OSError(12), "Error connecting to localhost:6379. 12."), + (SSLConnection(), OSError(), "Error connecting to localhost:6379."), + (SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."), ( + SSLConnection(), OSError(12, "Some Error"), - "Error 12 connecting to localhost:6379. [Errno 12] Some Error.", + "Error 12 connecting to localhost:6379. Some Error.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(), + "Error connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12), + "Error 12 connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12, "Some Error"), + "Error 12 connecting to unix:///tmp/redis.sock. Some Error.", ), ], ) -async def test_connect_error_message(error, expected_message): +async def test_format_error_message(conn, error, expected_message): """Test that the _error_message function formats errors correctly""" - conn = Connection() error_message = conn._error_message(error) assert error_message == expected_message + + +async def test_network_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(host="127.0.0.1", port=9999) + await redis.set("a", "b") + assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect") + + +async def test_unix_socket_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(unix_socket_path="unix:///tmp/a.sock") + await redis.set("a", "b") + assert ( + str(e.value) + == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + ) diff --git a/tests/test_connection.py b/tests/test_connection.py index bff249559e..69275d58c0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -296,3 +296,53 @@ def mock_disconnect(_): assert called == 1 pool.disconnect() + + +@pytest.mark.parametrize( + "conn, error, expected_message", + [ + (SSLConnection(), OSError(), "Error connecting to localhost:6379."), + (SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."), + ( + SSLConnection(), + OSError(12, "Some Error"), + "Error 12 connecting to localhost:6379. Some Error.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(), + "Error connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12), + "Error 12 connecting to unix:///tmp/redis.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/redis.sock"), + OSError(12, "Some Error"), + "Error 12 connecting to unix:///tmp/redis.sock. Some Error.", + ), + ], +) +def test_format_error_message(conn, error, expected_message): + """Test that the _error_message function formats errors correctly""" + error_message = conn._error_message(error) + assert error_message == expected_message + + +def test_network_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(port=9999) + redis.set("a", "b") + assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused." + + +def test_unix_socket_connection_failure(): + with pytest.raises(ConnectionError) as e: + redis = Redis(unix_socket_path="unix:///tmp/a.sock") + redis.set("a", "b") + assert ( + str(e.value) + == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + )