From 5bed49a978dffe8afee66709eb8098de72624fbd Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 6 Feb 2023 02:15:15 +0200 Subject: [PATCH 01/21] Reorganizing the parsers code --- benchmarks/socket_read_size.py | 4 +- redis/asyncio/__init__.py | 2 - redis/asyncio/cluster.py | 12 +- redis/asyncio/connection.py | 371 +------------ redis/asyncio/parser.py | 94 ---- redis/cluster.py | 5 +- redis/commands/__init__.py | 2 - redis/connection.py | 498 +----------------- redis/parsers/__init__.py | 16 + redis/parsers/base.py | 227 ++++++++ .../parser.py => parsers/commands.py} | 102 +++- redis/parsers/encoders.py | 44 ++ redis/parsers/hiredis.py | 211 ++++++++ redis/parsers/resp2.py | 134 +++++ redis/parsers/socket.py | 161 ++++++ redis/typing.py | 19 +- redis/utils.py | 7 + tests/test_asyncio/conftest.py | 17 +- tests/test_asyncio/test_cluster.py | 8 +- tests/test_asyncio/test_connection.py | 12 +- tests/test_asyncio/test_pubsub.py | 4 +- tests/test_cluster.py | 2 +- tests/test_command_parser.py | 2 +- tests/test_connection_pool.py | 5 +- tests/test_pubsub.py | 4 +- whitelist.py | 1 - 26 files changed, 965 insertions(+), 999 deletions(-) delete mode 100644 redis/asyncio/parser.py create mode 100644 redis/parsers/__init__.py create mode 100644 redis/parsers/base.py rename redis/{commands/parser.py => parsers/commands.py} (63%) create mode 100644 redis/parsers/encoders.py create mode 100644 redis/parsers/hiredis.py create mode 100644 redis/parsers/resp2.py create mode 100644 redis/parsers/socket.py diff --git a/benchmarks/socket_read_size.py b/benchmarks/socket_read_size.py index 3427956ced..544c733178 100644 --- a/benchmarks/socket_read_size.py +++ b/benchmarks/socket_read_size.py @@ -1,12 +1,12 @@ from base import Benchmark -from redis.connection import HiredisParser, PythonParser +from redis.connection import PythonParser, _HiredisParser class SocketReadBenchmark(Benchmark): ARGUMENTS = ( - {"name": "parser", "values": [PythonParser, HiredisParser]}, + {"name": "parser", "values": [PythonParser, _HiredisParser]}, { "name": "value_size", "values": [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000], diff --git a/redis/asyncio/__init__.py b/redis/asyncio/__init__.py index bf90dde555..7b9508334d 100644 --- a/redis/asyncio/__init__.py +++ b/redis/asyncio/__init__.py @@ -7,7 +7,6 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.asyncio.parser import CommandsParser from redis.asyncio.sentinel import ( Sentinel, SentinelConnectionPool, @@ -38,7 +37,6 @@ "BlockingConnectionPool", "BusyLoadingError", "ChildDeadlockedError", - "CommandsParser", "Connection", "ConnectionError", "ConnectionPool", diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 5a2dffdd1d..905ece3965 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -17,15 +17,8 @@ ) from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import ( - Connection, - DefaultParser, - Encoder, - SSLConnection, - parse_url, -) +from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import default_backoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis @@ -60,6 +53,7 @@ TimeoutError, TryAgainError, ) +from redis.parsers import AsyncCommandsParser, Encoder from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import dict_merge, safe_str, str_if_bytes @@ -344,7 +338,7 @@ 
def __init__( self.cluster_error_retry_attempts = cluster_error_retry_attempts self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 - self.commands_parser = CommandsParser() + self.commands_parser = AsyncCommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2c75d4fcf1..2cc2ee7904 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -33,26 +33,17 @@ from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.typing import EncodableT, EncodedT +from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -hiredis = None -if HIREDIS_AVAILABLE: - import hiredis +from ..parsers import BaseParser, Encoder, _AsyncHiredisParser, _AsyncRESP2Parser SYM_STAR = b"*" SYM_DOLLAR = b"$" @@ -60,371 +51,19 @@ SYM_LF = b"\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - class _Sentinel(enum.Enum): sentinel = object() SENTINEL = _Sentinel.sentinel -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class _HiredisReaderArgs(TypedDict, total=False): - protocolError: Callable[[str], Exception] - replyError: Callable[[str], Exception] - encoding: Optional[str] - errors: Optional[str] - - -class Encoder: - """Encode strings to bytes-like and decode bytes-like to strings""" - - __slots__ = "encoding", "encoding_errors", "decode_responses" - - def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value: EncodableT) -> EncodedT: - """Return a bytestring or bytes-like representation of the value""" - if isinstance(value, str): - return value.encode(self.encoding, self.encoding_errors) - if isinstance(value, (bytes, memoryview)): - return value - if isinstance(value, (int, float)): - if isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. " - "Convert to a bytes, string, int or float first." - ) - return repr(value).encode() - # a value we don't know how to deal with. throw an error - typename = value.__class__.__name__ - raise DataError( - f"Invalid input of type: {typename!r}. " - "Convert to a bytes, string, int or float first." 
- ) - - def decode(self, value: EncodableT, force=False) -> EncodableT: - """Return a unicode string from the bytes-like representation""" - if self.decode_responses or force: - if isinstance(value, bytes): - return value.decode(self.encoding, self.encoding_errors) - if isinstance(value, memoryview): - return value.tobytes().decode(self.encoding, self.encoding_errors) - return value - - -ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] - - -class BaseParser: - """Plain Python parsing class""" - - __slots__ = "_stream", "_read_size" - - EXCEPTION_CLASSES: ExceptionMappingT = { - "ERR": { - "max number of clients reached": ConnectionError, - "Client sent AUTH, but no password is set": AuthenticationError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, # noqa: E501 - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def __init__(self, socket_read_size: int): - self._stream: Optional[asyncio.StreamReader] = None - self._read_size = socket_read_size - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def parse_error(self, response: str) -> ResponseError: - """Parse an error response""" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - def on_disconnect(self): - raise NotImplementedError() - - def on_connect(self, connection: "Connection"): - raise NotImplementedError() - - async def can_read_destructive(self) -> bool: - raise NotImplementedError() - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: - raise NotImplementedError() - - -class PythonParser(BaseParser): - """Plain Python parsing class""" - - __slots__ = BaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") - - def __init__(self, socket_read_size: int): - super().__init__(socket_read_size) - self.encoder: Optional[Encoder] = None - self._buffer = b"" - self._chunks = [] - self._pos = 0 - - def _clear(self): - self._buffer = b"" - self._chunks.clear() - - def on_connect(self, connection: "Connection"): - """Called when the stream connects""" - self._stream = connection._reader - if self._stream is None: - raise RedisError("Buffer is closed.") - - self.encoder = connection.encoder - - def on_disconnect(self): - """Called when the stream disconnects""" - if self._stream is not None: - self._stream = None - self.encoder = None - self._clear() - - async def can_read_destructive(self) -> bool: - if self._buffer: - return True - if self._stream is None: - raise RedisError("Buffer 
is closed.") - try: - async with async_timeout.timeout(0): - return await self._stream.read(1) - except asyncio.TimeoutError: - return False - - async def read_response(self, disable_decoding: bool = False): - if self._chunks: - # augment parsing buffer with previously read data - self._buffer += b"".join(self._chunks) - self._chunks.clear() - self._pos = 0 - response = await self._read_response(disable_decoding=disable_decoding) - # Successfully parsing a response allows us to clear our parsing buffer - self._clear() - return response - - async def _read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, ResponseError, None]: - if not self._stream or not self.encoder: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - raw = await self._readline() - response: Any - byte, response = raw[:1], raw[1:] - - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - self._clear() # Successful parse - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. - return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - response = int(response) - # bulk response - elif byte == b"$": - length = int(response) - if length == -1: - return None - response = await self._read(length) - # multi-bulk response - elif byte == b"*": - length = int(response) - if length == -1: - return None - response = [ - (await self._read_response(disable_decoding)) for _ in range(length) - ] - if isinstance(response, bytes) and disable_decoding is False: - response = self.encoder.decode(response) - return response - - async def _read(self, length: int) -> bytes: - """ - Read `length` bytes of data. These are assumed to be followed - by a '\r\n' terminator which is subsequently discarded. - """ - want = length + 2 - end = self._pos + want - if len(self._buffer) >= end: - result = self._buffer[self._pos : end - 2] - else: - tail = self._buffer[self._pos :] - try: - data = await self._stream.readexactly(want - len(tail)) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += want - return result - - async def _readline(self) -> bytes: - """ - read an unknown number of bytes up to the next '\r\n' - line separator, which is discarded. 
- """ - found = self._buffer.find(b"\r\n", self._pos) - if found >= 0: - result = self._buffer[self._pos : found] - else: - tail = self._buffer[self._pos :] - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - result = (tail + data)[:-2] - self._chunks.append(data) - self._pos += len(result) + 2 - return result - - -class HiredisParser(BaseParser): - """Parser class for connections using Hiredis""" - - __slots__ = BaseParser.__slots__ + ("_reader",) - - def __init__(self, socket_read_size: int): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not available.") - super().__init__(socket_read_size=socket_read_size) - self._reader: Optional[hiredis.Reader] = None - - def on_connect(self, connection: "Connection"): - self._stream = connection._reader - kwargs: _HiredisReaderArgs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - } - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - kwargs["errors"] = connection.encoder.encoding_errors - - self._reader = hiredis.Reader(**kwargs) - - def on_disconnect(self): - self._stream = None - self._reader = None - - async def can_read_destructive(self): - if not self._stream or not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._reader.gets(): - return True - try: - async with async_timeout.timeout(0): - return await self.read_from_socket() - except asyncio.TimeoutError: - return False - - async def read_from_socket(self): - buffer = await self._stream.read(self._read_size) - if not buffer or not isinstance(buffer, bytes): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - self._reader.feed(buffer) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - - async def read_response( - self, disable_decoding: bool = False - ) -> Union[EncodableT, List[EncodableT]]: - if not self._stream or not self._reader: - self.on_disconnect() - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - - response = self._reader.gets() - while response is False: - await self.read_from_socket() - response = self._reader.gets() - - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response -DefaultParser: Type[Union[PythonParser, HiredisParser]] +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncHiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _AsyncHiredisParser else: - DefaultParser = PythonParser + DefaultParser = _AsyncRESP2Parser class ConnectCallbackProtocol(Protocol): diff --git a/redis/asyncio/parser.py b/redis/asyncio/parser.py deleted file mode 100644 index 5faf8f8c57..0000000000 --- a/redis/asyncio/parser.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union - -from redis.exceptions import RedisError, ResponseError - -if TYPE_CHECKING: - from redis.asyncio.cluster import ClusterNode - - -class CommandsParser: - """ - Parses Redis commands to get command keys. - - COMMAND output is used to determine key locations. 
- Commands that do not have a predefined key location are flagged with 'movablekeys', - and these commands' keys are determined by the command 'COMMAND GETKEYS'. - - NOTE: Due to a bug in redis<7.0, this does not work properly - for EVAL or EVALSHA when the `numkeys` arg is 0. - - issue: https://github.com/redis/redis/issues/9493 - - fix: https://github.com/redis/redis/pull/9733 - - So, don't use this with EVAL or EVALSHA. - """ - - __slots__ = ("commands", "node") - - def __init__(self) -> None: - self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} - - async def initialize(self, node: Optional["ClusterNode"] = None) -> None: - if node: - self.node = node - - commands = await self.node.execute_command("COMMAND") - for cmd, command in commands.items(): - if "movablekeys" in command["flags"]: - commands[cmd] = -1 - elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: - commands[cmd] = 0 - elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: - commands[cmd] = 1 - self.commands = {cmd.upper(): command for cmd, command in commands.items()} - - # As soon as this PR is merged into Redis, we should reimplement - # our logic to use COMMAND INFO changes to determine the key positions - # https://github.com/redis/redis/pull/8324 - async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - if len(args) < 2: - # The command has no keys in it - return None - - try: - command = self.commands[args[0]] - except KeyError: - # try to split the command name and to take only the main command - # e.g. 'memory' for 'memory usage' - args = args[0].split() + list(args[1:]) - cmd_name = args[0].upper() - if cmd_name not in self.commands: - # We'll try to reinitialize the commands cache, if the engine - # version has changed, the commands may not be current - await self.initialize() - if cmd_name not in self.commands: - raise RedisError( - f"{cmd_name} command doesn't exist in Redis commands" - ) - - command = self.commands[cmd_name] - - if command == 1: - return (args[1],) - if command == 0: - return None - if command == -1: - return await self._get_moveable_keys(*args) - - last_key_pos = command["last_key_pos"] - if last_key_pos < 0: - last_key_pos = len(args) + last_key_pos - return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] - - async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: - try: - keys = await self.node.execute_command("COMMAND GETKEYS", *args) - except ResponseError as e: - message = e.__str__() - if ( - "Invalid arguments" in message - or "The command has no key arguments" in message - ): - return None - else: - raise e - return keys diff --git a/redis/cluster.py b/redis/cluster.py index d6dc02d493..f8896372c5 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -8,8 +8,8 @@ from redis.backoff import default_backoff from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan -from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands -from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url +from redis.commands import READ_COMMANDS, RedisClusterCommands +from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.exceptions import ( AskError, @@ -29,6 +29,7 @@ TryAgainError, ) from redis.lock import Lock +from redis.parsers import CommandsParser, Encoder from redis.retry import Retry from redis.utils import ( dict_merge, diff --git 
a/redis/commands/__init__.py b/redis/commands/__init__.py index f3f08286c8..a94d9764a6 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -1,7 +1,6 @@ from .cluster import READ_COMMANDS, AsyncRedisClusterCommands, RedisClusterCommands from .core import AsyncCoreCommands, CoreCommands from .helpers import list_or_args -from .parser import CommandsParser from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands from .sentinel import AsyncSentinelCommands, SentinelCommands @@ -10,7 +9,6 @@ "AsyncRedisClusterCommands", "AsyncRedisModuleCommands", "AsyncSentinelCommands", - "CommandsParser", "CoreCommands", "READ_COMMANDS", "RedisClusterCommands", diff --git a/redis/connection.py b/redis/connection.py index 57f0a3a81e..43ac58dc91 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,512 +1,48 @@ import copy -import errno -import io import os import socket +import ssl import threading import weakref -from io import SEEK_END from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional, Union +from typing import Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from redis.backoff import NoBackoff -from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider -from redis.exceptions import ( +from .backoff import NoBackoff +from .credentials import CredentialProvider, UsernamePasswordCredentialProvider +from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, - BusyLoadingError, ChildDeadlockedError, ConnectionError, DataError, - ExecAbortError, - InvalidResponse, - ModuleError, - NoPermissionError, - NoScriptError, - ReadOnlyError, RedisError, ResponseError, TimeoutError, ) -from redis.retry import Retry -from redis.utils import CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, str_if_bytes - -try: - import ssl - - ssl_available = True -except ImportError: - ssl_available = False - -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} - -if ssl_available: - if hasattr(ssl, "SSLWantReadError"): - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 - else: - NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 - -NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) - -if HIREDIS_AVAILABLE: - import hiredis +from .parsers import Encoder, _HiredisParser, _RESP2Parser +from .retry import Retry +from .utils import ( + CRYPTOGRAPHY_AVAILABLE, + HIREDIS_AVAILABLE, + SSL_AVAILABLE, + str_if_bytes, +) SYM_STAR = b"*" SYM_DOLLAR = b"$" SYM_CRLF = b"\r\n" SYM_EMPTY = b"" -SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." - SENTINEL = object() -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." -NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." -MODULE_EXPORTS_DATA_TYPES_ERROR = ( - "Error unloading module: the module " - "exports one or more module-side data " - "types, can't unload" -) -# user send an AUTH cmd to a server without authorization configured -NO_AUTH_SET_ERROR = { - # Redis >= 6.0 - "AUTH called without any password " - "configured for the default user. 
Are you sure " - "your configuration is correct?": AuthenticationError, - # Redis < 6.0 - "Client sent AUTH, but no password is set": AuthenticationError, -} - - -class Encoder: - "Encode strings to bytes-like and decode bytes-like to strings" - - def __init__(self, encoding, encoding_errors, decode_responses): - self.encoding = encoding - self.encoding_errors = encoding_errors - self.decode_responses = decode_responses - - def encode(self, value): - "Return a bytestring or bytes-like representation of the value" - if isinstance(value, (bytes, memoryview)): - return value - elif isinstance(value, bool): - # special case bool since it is a subclass of int - raise DataError( - "Invalid input of type: 'bool'. Convert to a " - "bytes, string, int or float first." - ) - elif isinstance(value, (int, float)): - value = repr(value).encode() - elif not isinstance(value, str): - # a value we don't know how to deal with. throw an error - typename = type(value).__name__ - raise DataError( - f"Invalid input of type: '{typename}'. " - f"Convert to a bytes, string, int or float first." - ) - if isinstance(value, str): - value = value.encode(self.encoding, self.encoding_errors) - return value - - def decode(self, value, force=False): - "Return a unicode string from the bytes-like representation" - if self.decode_responses or force: - if isinstance(value, memoryview): - value = value.tobytes() - if isinstance(value, bytes): - value = value.decode(self.encoding, self.encoding_errors) - return value - - -class BaseParser: - EXCEPTION_CLASSES = { - "ERR": { - "max number of clients reached": ConnectionError, - "invalid password": AuthenticationError, - # some Redis server versions report invalid command syntax - # in lowercase - "wrong number of arguments " - "for 'auth' command": AuthenticationWrongNumberOfArgsError, - # some Redis server versions report invalid command syntax - # in uppercase - "wrong number of arguments " - "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, - MODULE_LOAD_ERROR: ModuleError, - MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, - NO_SUCH_MODULE_ERROR: ModuleError, - MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, - **NO_AUTH_SET_ERROR, - }, - "WRONGPASS": AuthenticationError, - "EXECABORT": ExecAbortError, - "LOADING": BusyLoadingError, - "NOSCRIPT": NoScriptError, - "READONLY": ReadOnlyError, - "NOAUTH": AuthenticationError, - "NOPERM": NoPermissionError, - } - - def parse_error(self, response): - "Parse an error response" - error_code = response.split(" ")[0] - if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] - exception_class = self.EXCEPTION_CLASSES[error_code] - if isinstance(exception_class, dict): - exception_class = exception_class.get(response, ResponseError) - return exception_class(response) - return ResponseError(response) - - -class SocketBuffer: - def __init__( - self, socket: socket.socket, socket_read_size: int, socket_timeout: float - ): - self._sock = socket - self.socket_read_size = socket_read_size - self.socket_timeout = socket_timeout - self._buffer = io.BytesIO() - - def unread_bytes(self) -> int: - """ - Remaining unread length of buffer - """ - pos = self._buffer.tell() - end = self._buffer.seek(0, SEEK_END) - self._buffer.seek(pos) - return end - pos - - def _read_from_socket( - self, - length: Optional[int] = None, - timeout: Union[float, object] = SENTINEL, - raise_on_timeout: Optional[bool] = True, - ) -> bool: - sock = self._sock - socket_read_size = self.socket_read_size - marker = 0 - custom_timeout = 
timeout is not SENTINEL - - buf = self._buffer - current_pos = buf.tell() - buf.seek(0, SEEK_END) - if custom_timeout: - sock.settimeout(timeout) - try: - while True: - data = self._sock.recv(socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - marker += data_length - - if length is not None and length > marker: - continue - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. - allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - buf.seek(current_pos) - if custom_timeout: - sock.settimeout(self.socket_timeout) - - def can_read(self, timeout: float) -> bool: - return bool(self.unread_bytes()) or self._read_from_socket( - timeout=timeout, raise_on_timeout=False - ) - - def read(self, length: int) -> bytes: - length = length + 2 # make sure to read the \r\n terminator - # BufferIO will return less than requested if buffer is short - data = self._buffer.read(length) - missing = length - len(data) - if missing: - # fill up the buffer and read the remainder - self._read_from_socket(missing) - data += self._buffer.read(missing) - return data[:-2] - - def readline(self) -> bytes: - buf = self._buffer - data = buf.readline() - while not data.endswith(SYM_CRLF): - # there's more data in the socket that we need - self._read_from_socket() - data += buf.readline() - - return data[:-2] - - def get_pos(self) -> int: - """ - Get current read position - """ - return self._buffer.tell() - - def rewind(self, pos: int) -> None: - """ - Rewind the buffer to a specific position, to re-start reading - """ - self._buffer.seek(pos) - - def purge(self) -> None: - """ - After a successful read, purge the read part of buffer - """ - unread = self.unread_bytes() - - # Only if we have read all of the buffer do we truncate, to - # reduce the amount of memory thrashing. This heuristic - # can be changed or removed later. - if unread > 0: - return - - if unread > 0: - # move unread data to the front - view = self._buffer.getbuffer() - view[:unread] = view[-unread:] - self._buffer.truncate(unread) - self._buffer.seek(0) - - def close(self) -> None: - try: - self._buffer.close() - except Exception: - # issue #633 suggests the purge/close somehow raised a - # BadFileDescriptor error. Perhaps the client ran out of - # memory or something else? It's probably OK to ignore - # any error being raised from purge/close since we're - # removing the reference to the instance below. 
- pass - self._buffer = None - self._sock = None - - -class PythonParser(BaseParser): - "Plain Python parsing class" - - def __init__(self, socket_read_size): - self.socket_read_size = socket_read_size - self.encoder = None - self._sock = None - self._buffer = None - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection): - "Called when the socket connects" - self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) - self.encoder = connection.encoder - - def on_disconnect(self): - "Called when the socket disconnects" - self._sock = None - if self._buffer is not None: - self._buffer.close() - self._buffer = None - self.encoder = None - - def can_read(self, timeout): - return self._buffer and self._buffer.can_read(timeout) - - def read_response(self, disable_decoding=False): - pos = self._buffer.get_pos() - try: - result = self._read_response(disable_decoding=disable_decoding) - except BaseException: - self._buffer.rewind(pos) - raise - else: - self._buffer.purge() - return result - - def _read_response(self, disable_decoding=False): - raw = self._buffer.readline() - if not raw: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - byte, response = raw[:1], raw[1:] - - if byte not in (b"-", b"+", b":", b"$", b"*"): - raise InvalidResponse(f"Protocol Error: {raw!r}") - - # server returned an error - if byte == b"-": - response = response.decode("utf-8", errors="replace") - error = self.parse_error(response) - # if the error is a ConnectionError, raise immediately so the user - # is notified - if isinstance(error, ConnectionError): - raise error - # otherwise, we're dealing with a ResponseError that might belong - # inside a pipeline response. the connection's read_response() - # and/or the pipeline's execute() will raise this error if - # necessary, so just return the exception instance here. 
- return error - # single value - elif byte == b"+": - pass - # int value - elif byte == b":": - response = int(response) - # bulk response - elif byte == b"$": - length = int(response) - if length == -1: - return None - response = self._buffer.read(length) - # multi-bulk response - elif byte == b"*": - length = int(response) - if length == -1: - return None - response = [ - self._read_response(disable_decoding=disable_decoding) - for i in range(length) - ] - if isinstance(response, bytes) and disable_decoding is False: - response = self.encoder.decode(response) - return response - - -class HiredisParser(BaseParser): - "Parser class for connections using Hiredis" - - def __init__(self, socket_read_size): - if not HIREDIS_AVAILABLE: - raise RedisError("Hiredis is not installed") - self.socket_read_size = socket_read_size - self._buffer = bytearray(socket_read_size) - - def __del__(self): - try: - self.on_disconnect() - except Exception: - pass - - def on_connect(self, connection, **kwargs): - self._sock = connection._sock - self._socket_timeout = connection.socket_timeout - kwargs = { - "protocolError": InvalidResponse, - "replyError": self.parse_error, - "errors": connection.encoder.encoding_errors, - } - - if connection.encoder.decode_responses: - kwargs["encoding"] = connection.encoder.encoding - self._reader = hiredis.Reader(**kwargs) - self._next_response = False - - def on_disconnect(self): - self._sock = None - self._reader = None - self._next_response = False - - def can_read(self, timeout): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - if self._next_response is False: - self._next_response = self._reader.gets() - if self._next_response is False: - return self.read_from_socket(timeout=timeout, raise_on_timeout=False) - return True - - def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): - sock = self._sock - custom_timeout = timeout is not SENTINEL - try: - if custom_timeout: - sock.settimeout(timeout) - bufflen = self._sock.recv_into(self._buffer) - if bufflen == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - self._reader.feed(self._buffer, 0, bufflen) - # data was read from the socket and added to the buffer. - # return True to indicate that data was read. - return True - except socket.timeout: - if raise_on_timeout: - raise TimeoutError("Timeout reading from socket") - return False - except NONBLOCKING_EXCEPTIONS as ex: - # if we're in nonblocking mode and the recv raises a - # blocking error, simply return False indicating that - # there's no data to be read. otherwise raise the - # original exception. 
- allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) - if not raise_on_timeout and ex.errno == allowed: - return False - raise ConnectionError(f"Error while reading from socket: {ex.args}") - finally: - if custom_timeout: - sock.settimeout(self._socket_timeout) - - def read_response(self, disable_decoding=False): - if not self._reader: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - - # _next_response might be cached from a can_read() call - if self._next_response is not False: - response = self._next_response - self._next_response = False - return response - - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - - while response is False: - self.read_from_socket() - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - # if the response is a ConnectionError or the response is a list and - # the first item is a ConnectionError, raise it as something bad - # happened - if isinstance(response, ConnectionError): - raise response - elif ( - isinstance(response, list) - and response - and isinstance(response[0], ConnectionError) - ): - raise response[0] - return response - -DefaultParser: BaseParser +DefaultParser: Type[Union[_RESP2Parser, _HiredisParser]] if HIREDIS_AVAILABLE: - DefaultParser = HiredisParser + DefaultParser = _HiredisParser else: - DefaultParser = PythonParser + DefaultParser = _RESP2Parser class Connection: @@ -987,7 +523,7 @@ def __init__( Raises: RedisError """ # noqa - if not ssl_available: + if not SSL_AVAILABLE: raise RedisError("Python wasn't built with SSL support") super().__init__(**kwargs) diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py new file mode 100644 index 0000000000..68b32ed5ba --- /dev/null +++ b/redis/parsers/__init__.py @@ -0,0 +1,16 @@ +from .base import BaseParser +from .commands import AsyncCommandsParser, CommandsParser +from .encoders import Encoder +from .hiredis import _AsyncHiredisParser, _HiredisParser +from .resp2 import _AsyncRESP2Parser, _RESP2Parser + +__all__ = [ + "AsyncCommandsParser", + "_AsyncHiredisParser", + "_AsyncRESP2Parser", + "CommandsParser", + "Encoder", + "BaseParser", + "_HiredisParser", + "_RESP2Parser", +] diff --git a/redis/parsers/base.py b/redis/parsers/base.py new file mode 100644 index 0000000000..fbf4e674d5 --- /dev/null +++ b/redis/parsers/base.py @@ -0,0 +1,227 @@ +from abc import ABC +from asyncio import IncompleteReadError, StreamReader, TimeoutError +from typing import List, Optional, Union + +import async_timeout + +from ..exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ConnectionError, + ExecAbortError, + ModuleError, + NoPermissionError, + NoScriptError, + ReadOnlyError, + RedisError, + ResponseError, +) +from ..typing import EncodableT +from .encoders import Encoder +from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer + +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" +) +# user send an AUTH cmd to a server without authorization configured +NO_AUTH_SET_ERROR = { + # Redis >= 6.0 + "AUTH called without any password " + "configured for the default user. 
Are you sure " + "your configuration is correct?": AuthenticationError, + # Redis < 6.0 + "Client sent AUTH, but no password is set": AuthenticationError, +} + + +class BaseParser(ABC): + + EXCEPTION_CLASSES = { + "ERR": { + "max number of clients reached": ConnectionError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments " + "for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments " + "for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + **NO_AUTH_SET_ERROR, + }, + "WRONGPASS": AuthenticationError, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + "READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + def parse_error(self, response): + "Parse an error response" + error_code = response.split(" ")[0] + if error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection): + raise NotImplementedError() + + +class _RESPBase(BaseParser): + """Base class for sync-based resp parsing""" + + def __init__(self, socket_read_size): + self.socket_read_size = socket_read_size + self.encoder = None + self._sock = None + self._buffer = None + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection): + "Called when the socket connects" + self._sock = connection._sock + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) + self.encoder = connection.encoder + + def on_disconnect(self): + "Called when the socket disconnects" + self._sock = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None + + def can_read(self, timeout): + return self._buffer and self._buffer.can_read(timeout) + + +class AsyncBaseParser(BaseParser): + """Base parsing class for the python-backed async parser""" + + __slots__ = "_stream", "_read_size" + + def __init__(self, socket_read_size: int): + self._stream: Optional[StreamReader] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + async def can_read_destructive(self) -> bool: + raise NotImplementedError() + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None, List[EncodableT]]: + raise NotImplementedError() + + +class _AsyncRESPBase(AsyncBaseParser): + """Base class for async resp parsing""" + + __slots__ = AsyncBaseParser.__slots__ + ("encoder", "_buffer", "_pos", "_chunks") + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + self._buffer = b"" + self._chunks = [] + self._pos = 0 + + def _clear(self): + self._buffer = b"" + self._chunks.clear() + + def on_connect(self, connection): + """Called when the stream connects""" + 
self._stream = connection._reader + if self._stream is None: + raise RedisError("Buffer is closed.") + + self.encoder = connection.encoder + + def on_disconnect(self): + """Called when the stream disconnects""" + if self._stream is not None: + self._stream = None + self.encoder = None + self._clear() + + async def can_read_destructive(self) -> bool: + if self._buffer: + return True + if self._stream is None: + raise RedisError("Buffer is closed.") + try: + async with async_timeout.timeout(0): + return await self._stream.read(1) + except TimeoutError: + return False + + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + want = length + 2 + end = self._pos + want + if len(self._buffer) >= end: + result = self._buffer[self._pos : end - 2] + else: + tail = self._buffer[self._pos :] + try: + data = await self._stream.readexactly(want - len(tail)) + except IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += want + return result + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. + """ + found = self._buffer.find(b"\r\n", self._pos) + if found >= 0: + result = self._buffer[self._pos : found] + else: + tail = self._buffer[self._pos :] + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + result = (tail + data)[:-2] + self._chunks.append(data) + self._pos += len(result) + 2 + return result diff --git a/redis/commands/parser.py b/redis/parsers/commands.py similarity index 63% rename from redis/commands/parser.py rename to redis/parsers/commands.py index 115230a9d2..496336acaa 100644 --- a/redis/commands/parser.py +++ b/redis/parsers/commands.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + from redis.exceptions import RedisError, ResponseError from redis.utils import str_if_bytes +if TYPE_CHECKING: + from redis.asyncio.cluster import ClusterNode + class CommandsParser: """ @@ -16,7 +21,7 @@ def __init__(self, redis_connection): self.initialize(redis_connection) def initialize(self, r): - commands = r.execute_command("COMMAND") + commands = r.command() uppercase_commands = [] for cmd in commands: if any(x.isupper() for x in cmd): @@ -117,14 +122,11 @@ def _get_moveable_keys(self, redis_conn, *args): So, don't use this function with EVAL or EVALSHA. """ - pieces = [] - cmd_name = args[0] # The command name should be splitted into separate arguments, # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] - pieces = pieces + cmd_name.split() - pieces = pieces + list(args[1:]) + pieces = args[0].split() + list(args[1:]) try: - keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces) + keys = redis_conn.command_getkeys(*pieces) except ResponseError as e: message = e.__str__() if ( @@ -164,3 +166,91 @@ def _get_pubsub_keys(self, *args): # PUBLISH channel message keys = [args[1]] return keys + + +class AsyncCommandsParser: + """ + Parses Redis commands to get command keys. + + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with 'movablekeys', + and these commands' keys are determined by the command 'COMMAND GETKEYS'. 
+ + NOTE: Due to a bug in redis<7.0, this does not work properly + for EVAL or EVALSHA when the `numkeys` arg is 0. + - issue: https://github.com/redis/redis/issues/9493 + - fix: https://github.com/redis/redis/pull/9733 + + So, don't use this with EVAL or EVALSHA. + """ + + __slots__ = ("commands", "node") + + def __init__(self) -> None: + self.commands: Dict[str, Union[int, Dict[str, Any]]] = {} + + async def initialize(self, node: Optional["ClusterNode"] = None) -> None: + if node: + self.node = node + + commands = await self.node.execute_command("COMMAND") + for cmd, command in commands.items(): + if "movablekeys" in command["flags"]: + commands[cmd] = -1 + elif command["first_key_pos"] == 0 and command["last_key_pos"] == 0: + commands[cmd] = 0 + elif command["first_key_pos"] == 1 and command["last_key_pos"] == 1: + commands[cmd] = 1 + self.commands = {cmd.upper(): command for cmd, command in commands.items()} + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + async def get_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + if len(args) < 2: + # The command has no keys in it + return None + + try: + command = self.commands[args[0]] + except KeyError: + # try to split the command name and to take only the main command + # e.g. 'memory' for 'memory usage' + args = args[0].split() + list(args[1:]) + cmd_name = args[0].upper() + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + await self.initialize() + if cmd_name not in self.commands: + raise RedisError( + f"{cmd_name} command doesn't exist in Redis commands" + ) + + command = self.commands[cmd_name] + + if command == 1: + return (args[1],) + if command == 0: + return None + if command == -1: + return await self._get_moveable_keys(*args) + + last_key_pos = command["last_key_pos"] + if last_key_pos < 0: + last_key_pos = len(args) + last_key_pos + return args[command["first_key_pos"] : last_key_pos + 1 : command["step_count"]] + + async def _get_moveable_keys(self, *args: Any) -> Optional[Tuple[str, ...]]: + try: + keys = await self.node.execute_command("COMMAND GETKEYS", *args) + except ResponseError as e: + message = e.__str__() + if ( + "Invalid arguments" in message + or "The command has no key arguments" in message + ): + return None + else: + raise e + return keys diff --git a/redis/parsers/encoders.py b/redis/parsers/encoders.py new file mode 100644 index 0000000000..6fdf0ad882 --- /dev/null +++ b/redis/parsers/encoders.py @@ -0,0 +1,44 @@ +from ..exceptions import DataError + + +class Encoder: + "Encode strings to bytes-like and decode bytes-like to strings" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding, encoding_errors, decode_responses): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value): + "Return a bytestring or bytes-like representation of the value" + if isinstance(value, (bytes, memoryview)): + return value + elif isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. Convert to a " + "bytes, string, int or float first." 
+ ) + elif isinstance(value, (int, float)): + value = repr(value).encode() + elif not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = type(value).__name__ + raise DataError( + f"Invalid input of type: '{typename}'. " + f"Convert to a bytes, string, int or float first." + ) + if isinstance(value, str): + value = value.encode(self.encoding, self.encoding_errors) + return value + + def decode(self, value, force=False): + "Return a unicode string from the bytes-like representation" + if self.decode_responses or force: + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytes): + value = value.decode(self.encoding, self.encoding_errors) + return value diff --git a/redis/parsers/hiredis.py b/redis/parsers/hiredis.py new file mode 100644 index 0000000000..1d8cc5923f --- /dev/null +++ b/redis/parsers/hiredis.py @@ -0,0 +1,211 @@ +import asyncio +import socket +from typing import Callable, List, Optional, Union + +import async_timeout + +from redis.compat import TypedDict + +from ..exceptions import ConnectionError, InvalidResponse, RedisError +from ..typing import EncodableT +from ..utils import HIREDIS_AVAILABLE +from .base import AsyncBaseParser, BaseParser +from .socket import ( + NONBLOCKING_EXCEPTION_ERROR_NUMBERS, + NONBLOCKING_EXCEPTIONS, + SENTINEL, + SERVER_CLOSED_CONNECTION_ERROR, +) + + +class _HiredisReaderArgs(TypedDict, total=False): + protocolError: Callable[[str], Exception] + replyError: Callable[[str], Exception] + encoding: Optional[str] + errors: Optional[str] + + +class _HiredisParser(BaseParser): + "Parser class for connections using Hiredis" + + def __init__(self, socket_read_size): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not installed") + self.socket_read_size = socket_read_size + self._buffer = bytearray(socket_read_size) + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def on_connect(self, connection, **kwargs): + import hiredis + + self._sock = connection._sock + self._socket_timeout = connection.socket_timeout + kwargs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + "errors": connection.encoder.encoding_errors, + } + + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + + def on_disconnect(self): + self._sock = None + self._reader = None + self._next_response = False + + def can_read(self, timeout): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): + sock = self._sock + custom_timeout = timeout is not SENTINEL + try: + if custom_timeout: + sock.settimeout(timeout) + bufflen = self._sock.recv_into(self._buffer) + if bufflen == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + self._reader.feed(self._buffer, 0, bufflen) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. 
+ return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + if custom_timeout: + sock.settimeout(self._socket_timeout) + + def read_response(self, disable_decoding=False): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + + while response is False: + self.read_from_socket() + if disable_decoding: + response = self._reader.gets(False) + else: + response = self._reader.gets() + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response + + +class _AsyncHiredisParser(AsyncBaseParser): + """Async implementation of parser class for connections using Hiredis""" + + __slots__ = AsyncBaseParser.__slots__ + ("_reader",) + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._reader = None + + def on_connect(self, connection): + import hiredis + + self._stream = connection._reader + kwargs: _HiredisReaderArgs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs["encoding"] = connection.encoder.encoding + kwargs["errors"] = connection.encoder.encoding_errors + + self._reader = hiredis.Reader(**kwargs) + + def on_disconnect(self): + self._stream = None + self._reader = None + + async def can_read_destructive(self): + if not self._stream or not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + if self._reader.gets(): + return True + try: + async with async_timeout.timeout(0): + return await self.read_from_socket() + except asyncio.TimeoutError: + return False + + async def read_from_socket(self): + buffer = await self._stream.read(self._read_size) + if not buffer or not isinstance(buffer, bytes): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. 
+ return True + + async def read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, List[EncodableT]]: + if not self._stream or not self._reader: + self.on_disconnect() + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response diff --git a/redis/parsers/resp2.py b/redis/parsers/resp2.py new file mode 100644 index 0000000000..63fd67ab5b --- /dev/null +++ b/redis/parsers/resp2.py @@ -0,0 +1,134 @@ +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP2Parser(_RESPBase): + """RESP2 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + if byte not in (b"-", b"+", b":", b"$", b"*"): + raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. 
+ return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + response = int(response) + # bulk response + elif byte == b"$": + length = int(response) + if length == -1: + return None + response = self._buffer.read(length) + # multi-bulk response + elif byte == b"*": + length = int(response) + if length == -1: + return None + response = [ + self._read_response(disable_decoding=disable_decoding) + for i in range(length) + ] + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP2Parser(_AsyncRESPBase): + """Async class for the RESP2 protocol""" + + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + if byte not in (b"-", b"+", b":", b"$", b"*"): + raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + response = int(response) + # bulk response + elif byte == b"$": + length = int(response) + if length == -1: + return None + response = await self._read(length) + # multi-bulk response + elif byte == b"*": + length = int(response) + if length == -1: + return None + response = [ + (await self._read_response(disable_decoding)) for _ in range(length) + ] + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py new file mode 100644 index 0000000000..5b3681bd19 --- /dev/null +++ b/redis/parsers/socket.py @@ -0,0 +1,161 @@ +import errno +import io +import socket +from io import SEEK_END +from typing import Optional, Union + +from ..connection import SYM_CRLF +from ..exceptions import ConnectionError, TimeoutError +from ..utils import SSL_AVAILABLE + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} + +if SSL_AVAILABLE: + import ssl + + if hasattr(ssl, "SSLWantReadError"): + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantReadError] = 2 + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLWantWriteError] = 2 + else: + NONBLOCKING_EXCEPTION_ERROR_NUMBERS[ssl.SSLError] = 2 + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." 
+SENTINEL = object() + + +class SocketBuffer: + def __init__( + self, socket: socket.socket, socket_read_size: int, socket_timeout: float + ): + self._sock = socket + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer = io.BytesIO() + + def unread_bytes(self) -> int: + """ + Remaining unread length of buffer + """ + pos = self._buffer.tell() + end = self._buffer.seek(0, SEEK_END) + self._buffer.seek(pos) + return end - pos + + def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, object] = SENTINEL, + raise_on_timeout: Optional[bool] = True, + ) -> bool: + sock = self._sock + socket_read_size = self.socket_read_size + marker = 0 + custom_timeout = timeout is not SENTINEL + + buf = self._buffer + current_pos = buf.tell() + buf.seek(0, SEEK_END) + if custom_timeout: + sock.settimeout(timeout) + try: + while True: + data = self._sock.recv(socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + marker += data_length + + if length is not None and length > marker: + continue + return True + except socket.timeout: + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + finally: + buf.seek(current_pos) + if custom_timeout: + sock.settimeout(self.socket_timeout) + + def can_read(self, timeout: float) -> bool: + return bool(self.unread_bytes()) or self._read_from_socket( + timeout=timeout, raise_on_timeout=False + ) + + def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # BufferIO will return less than requested if buffer is short + data = self._buffer.read(length) + missing = length - len(data) + if missing: + # fill up the buffer and read the remainder + self._read_from_socket(missing) + data += self._buffer.read(missing) + return data[:-2] + + def readline(self) -> bytes: + buf = self._buffer + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + self._read_from_socket() + data += buf.readline() + + return data[:-2] + + def get_pos(self) -> int: + """ + Get current read position + """ + return self._buffer.tell() + + def rewind(self, pos: int) -> None: + """ + Rewind the buffer to a specific position, to re-start reading + """ + self._buffer.seek(pos) + + def purge(self) -> None: + """ + After a successful read, purge the read part of buffer + """ + unread = self.unread_bytes() + + # Only if we have read all of the buffer do we truncate, to + # reduce the amount of memory thrashing. This heuristic + # can be changed or removed later. 
+ if unread > 0: + return + + if unread > 0: + # move unread data to the front + view = self._buffer.getbuffer() + view[:unread] = view[-unread:] + self._buffer.truncate(unread) + self._buffer.seek(0) + + def close(self) -> None: + try: + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. + pass + self._buffer = None + self._sock = None diff --git a/redis/typing.py b/redis/typing.py index 8504c7de0c..22d7775ab9 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -1,14 +1,23 @@ # from __future__ import annotations from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Awaitable, Iterable, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Iterable, + Mapping, + Type, + TypeVar, + Union, +) from redis.compat import Protocol if TYPE_CHECKING: from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool - from redis.asyncio.connection import Encoder as AsyncEncoder - from redis.connection import ConnectionPool, Encoder + from redis.connection import ConnectionPool + from redis.parsers import Encoder Number = Union[int, float] @@ -39,6 +48,8 @@ AnyFieldT = TypeVar("AnyFieldT", bytes, str, memoryview) AnyChannelT = TypeVar("AnyChannelT", bytes, str, memoryview) +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] + class CommandsProtocol(Protocol): connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] @@ -48,7 +59,7 @@ def execute_command(self, *args, **options): class ClusterCommandsProtocol(CommandsProtocol): - encoder: Union["AsyncEncoder", "Encoder"] + encoder: Encoder def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... 
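The SocketBuffer moved into redis/parsers/socket.py above is what lets the relocated _RESP2Parser resume a reply after an interrupted read: read_response() records the buffer position, rewinds it if anything goes wrong, and purges consumed bytes only once a full reply has been parsed (the behaviour test_connection_parse_response_resume checks against a mock socket later in this series). Below is a minimal, self-contained sketch of that rewind-on-failure pattern; ReplayBuffer and read_reply are illustrative stand-ins that feed pre-loaded bytes instead of a socket and are not part of the patch.

    import io

    class ReplayBuffer:
        # Illustrative stand-in for SocketBuffer: data is pre-loaded instead
        # of arriving from a socket, but get_pos()/rewind() behave the same
        # way as in redis/parsers/socket.py.
        def __init__(self, data: bytes):
            self._buffer = io.BytesIO(data)

        def readline(self) -> bytes:
            line = self._buffer.readline()
            if not line.endswith(b"\r\n"):
                # simulate running out of socket data mid-reply
                raise TimeoutError("need more data from the socket")
            return line[:-2]

        def get_pos(self) -> int:
            return self._buffer.tell()

        def rewind(self, pos: int) -> None:
            self._buffer.seek(pos)

    def read_reply(buf: ReplayBuffer) -> bytes:
        # same shape as _RESP2Parser.read_response(): remember where the
        # reply started, rewind on failure so a retry re-parses it in full
        pos = buf.get_pos()
        try:
            return buf.readline()
        except BaseException:
            buf.rewind(pos)
            raise

    buf = ReplayBuffer(b"+OK\r\n+PAR")
    print(read_reply(buf))   # b'+OK'
    try:
        read_reply(buf)      # second reply is incomplete -> buffer is rewound
    except TimeoutError:
        pass
    print(buf.get_pos())     # 5, still at the start of the unfinished reply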
diff --git a/redis/utils.py b/redis/utils.py index 693d4e64b5..1171ed0aba 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -10,6 +10,13 @@ except ImportError: HIREDIS_AVAILABLE = False +try: + import ssl # noqa + + SSL_AVAILABLE = True +except ImportError: + SSL_AVAILABLE = False + try: import cryptography # noqa diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 6982cc840a..28a6f0626f 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -9,14 +9,11 @@ import redis.asyncio as redis from redis.asyncio.client import Monitor -from redis.asyncio.connection import ( - HIREDIS_AVAILABLE, - HiredisParser, - PythonParser, - parse_url, -) +from redis.asyncio.connection import parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff +from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser +from redis.utils import HIREDIS_AVAILABLE from tests.conftest import REDIS_INFO from .compat import mock @@ -32,14 +29,14 @@ async def _get_info(redis_url): @pytest_asyncio.fixture( params=[ pytest.param( - (True, PythonParser), + (True, _AsyncRESP2Parser), marks=pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', reason="cluster mode enabled" ), ), - (False, PythonParser), + (False, _AsyncRESP2Parser), pytest.param( - (True, HiredisParser), + (True, _AsyncHiredisParser), marks=[ pytest.mark.skipif( 'config.REDIS_INFO["cluster_enabled"]', @@ -51,7 +48,7 @@ async def _get_info(redis_url): ], ), pytest.param( - (False, HiredisParser), + (False, _AsyncHiredisParser), marks=pytest.mark.skipif( not HIREDIS_AVAILABLE, reason="hiredis is not installed" ), diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 13e5e26ae3..e8cc1955cf 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -12,7 +12,6 @@ from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster from redis.asyncio.connection import Connection, SSLConnection -from redis.asyncio.parser import CommandsParser from redis.asyncio.retry import Retry from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import PIPELINE_BLOCKED_COMMANDS, PRIMARY, REPLICA, get_node_name @@ -29,6 +28,7 @@ RedisError, ResponseError, ) +from redis.parsers import AsyncCommandsParser from redis.utils import str_if_bytes from tests.conftest import ( skip_if_redis_enterprise, @@ -99,7 +99,7 @@ async def execute_command(*_args, **_kwargs): execute_command_mock.side_effect = execute_command with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -549,7 +549,7 @@ def map_7007(self): mocks["send_packed_command"].return_value = "MOCK_OK" mocks["connect"].return_value = None with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: @@ -2341,7 +2341,7 @@ async def mocked_execute_command(self, *args, **kwargs): assert "Redis Cluster cannot be connected" in str(e.value) with mock.patch.object( - CommandsParser, "initialize", autospec=True + AsyncCommandsParser, "initialize", autospec=True ) as cmd_parser_initialize: def cmd_init_mock(self, r: ClusterNode) -> None: diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8e4fdac309..f8464a7ed5 100644 
--- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -7,15 +7,11 @@ import redis from redis.asyncio import Redis -from redis.asyncio.connection import ( - BaseParser, - Connection, - PythonParser, - UnixDomainSocketConnection, -) +from redis.asyncio.connection import Connection, UnixDomainSocketConnection from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.parsers import _AsyncRESP2Parser from tests.conftest import skip_if_server_version_lt from .compat import mock @@ -29,11 +25,11 @@ async def test_invalid_response(create_redis): raw = b"x" fake_stream = MockStream(raw + b"\r\n") - parser: BaseParser = r.connection._parser + parser: _AsyncRESP2Parser = r.connection._parser with mock.patch.object(parser, "_stream", fake_stream): with pytest.raises(InvalidResponse) as cm: await parser.read_response() - if isinstance(parser, PythonParser): + if isinstance(parser, _AsyncRESP2Parser): assert str(cm.value) == f"Protocol Error: {raw!r}" else: assert ( diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index c2a9130e83..e8a9fac4f0 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -991,9 +991,9 @@ async def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert pubsub.connection.is_connected - with patch("redis.asyncio.connection.PythonParser.read_response") as mock1: + with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1: mock1.side_effect = BaseException("boom") - with patch("redis.asyncio.connection.HiredisParser.read_response") as mock2: + with patch("redis.parsers._AsyncHiredisParser.read_response") as mock2: mock2.side_effect = BaseException("boom") with pytest.raises(BaseException): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 1bf57a357c..3dc23bb852 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -17,7 +17,6 @@ RedisCluster, get_node_name, ) -from redis.commands import CommandsParser from redis.connection import BlockingConnectionPool, Connection, ConnectionPool from redis.crc import key_slot from redis.exceptions import ( @@ -32,6 +31,7 @@ ResponseError, TimeoutError, ) +from redis.parsers import CommandsParser from redis.retry import Retry from redis.utils import str_if_bytes from tests.test_pubsub import wait_for_message diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index 6c3ede9cdf..b2a2268f85 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,6 +1,6 @@ import pytest -from redis.commands import CommandsParser +from redis.parsers import CommandsParser from .conftest import skip_if_redis_enterprise, skip_if_server_version_lt diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index e8a42692a1..ba9fef3089 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,7 +7,8 @@ import pytest import redis -from redis.connection import ssl_available, to_bool +from redis.connection import to_bool +from redis.utils import SSL_AVAILABLE from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt from .test_pubsub import wait_for_message @@ -425,7 +426,7 @@ class MyConnection(redis.UnixDomainSocketConnection): assert pool.connection_class == MyConnection -@pytest.mark.skipif(not ssl_available, reason="SSL not installed") +@pytest.mark.skipif(not 
SSL_AVAILABLE, reason="SSL not installed") class TestSSLConnectionURLParsing: def test_host(self): pool = redis.ConnectionPool.from_url("rediss://my.host") diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5d86934de6..48c0f3ac47 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -767,9 +767,9 @@ def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert is_connected() - with patch("redis.connection.PythonParser.read_response") as mock1: + with patch("redis.parsers._RESP2Parser.read_response") as mock1: mock1.side_effect = BaseException("boom") - with patch("redis.connection.HiredisParser.read_response") as mock2: + with patch("redis.parsers._HiredisParser.read_response") as mock2: mock2.side_effect = BaseException("boom") with pytest.raises(BaseException): diff --git a/whitelist.py b/whitelist.py index 8c9cee3c29..29cd529e4d 100644 --- a/whitelist.py +++ b/whitelist.py @@ -14,6 +14,5 @@ exc_value # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) traceback # unused variable (/data/repos/redis/redis-py/redis/asyncio/utils.py:26) AsyncConnectionPool # unused import (//data/repos/redis/redis-py/redis/typing.py:9) -AsyncEncoder # unused import (//data/repos/redis/redis-py/redis/typing.py:10) AsyncRedis # unused import (//data/repos/redis/redis-py/redis/commands/core.py:49) TargetNodesT # unused import (//data/repos/redis/redis-py/redis/commands/cluster.py:46) From 4d990f727185cbd5a0eb16b109cfa93c13ce9432 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 6 Feb 2023 02:26:27 +0200 Subject: [PATCH 02/21] fix build package --- redis/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/typing.py b/redis/typing.py index 22d7775ab9..7c5908ff0c 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -59,7 +59,7 @@ def execute_command(self, *args, **options): class ClusterCommandsProtocol(CommandsProtocol): - encoder: Encoder + encoder: "Encoder" def execute_command(self, *args, **options) -> Union[Any, Awaitable]: ... From 82f1a9ae2fc27de3bd2dbdcefcf439fd42034a99 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 6 Feb 2023 02:58:12 +0200 Subject: [PATCH 03/21] fix imports --- redis/parsers/socket.py | 3 ++- tests/test_connection.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/redis/parsers/socket.py b/redis/parsers/socket.py index 5b3681bd19..8147243bba 100644 --- a/redis/parsers/socket.py +++ b/redis/parsers/socket.py @@ -4,7 +4,6 @@ from io import SEEK_END from typing import Optional, Union -from ..connection import SYM_CRLF from ..exceptions import ConnectionError, TimeoutError from ..utils import SSL_AVAILABLE @@ -24,6 +23,8 @@ SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." 
SENTINEL = object() +SYM_CRLF = b"\r\n" + class SocketBuffer: def __init__( diff --git a/tests/test_connection.py b/tests/test_connection.py index e0b53cdf37..7ff13da8a3 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -7,8 +7,9 @@ import redis from redis.backoff import NoBackoff -from redis.connection import Connection, HiredisParser, PythonParser +from redis.connection import Connection from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.parsers import _HiredisParser, _RESP2Parser from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -128,7 +129,7 @@ def test_connect_timeout_error_without_retry(self): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", [_RESP2Parser, _HiredisParser], ids=["PythonParser", "HiredisParser"] ) def test_connection_parse_response_resume(r: redis.Redis, parser_class): """ @@ -136,7 +137,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): be that PythonParser or HiredisParser, can be interrupted at IO time and then resume parsing. """ - if parser_class is HiredisParser and not HIREDIS_AVAILABLE: + if parser_class is _HiredisParser and not HIREDIS_AVAILABLE: pytest.skip("Hiredis not available)") args = dict(r.connection_pool.connection_kwargs) args["parser_class"] = parser_class @@ -148,7 +149,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): ) mock_socket = MockSocket(message, interrupt_every=2) - if isinstance(conn._parser, PythonParser): + if isinstance(conn._parser, _RESP2Parser): conn._parser._buffer._sock = mock_socket else: conn._parser._sock = mock_socket From dac660471210d9290b5f84b4e96a54ecbcd53bd6 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 6 Feb 2023 03:02:35 +0200 Subject: [PATCH 04/21] fix flake8 --- tests/test_connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 7ff13da8a3..b96c076375 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -129,7 +129,9 @@ def test_connect_timeout_error_without_retry(self): @pytest.mark.onlynoncluster @pytest.mark.parametrize( - "parser_class", [_RESP2Parser, _HiredisParser], ids=["PythonParser", "HiredisParser"] + "parser_class", + [_RESP2Parser, _HiredisParser], + ids=["PythonParser", "HiredisParser"], ) def test_connection_parse_response_resume(r: redis.Redis, parser_class): """ From 5c8b66d934c4f67b4507cd61d5fe3afaabbe49c0 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 7 Feb 2023 12:09:56 +0200 Subject: [PATCH 05/21] add resp to Connection class --- redis/connection.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/redis/connection.py b/redis/connection.py index 43ac58dc91..ae0b051573 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -72,6 +72,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + resp: Optional[int] = None, ): """ Initialize a new Connection. 
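The new resp keyword (renamed to protocol later in this series) only takes effect in on_connect(): when a version is given, the connection sends HELLO <version> during the handshake, as the next hunk shows. For reference, this is roughly what that exchange looks like at the raw-connection level; a hedged sketch that assumes a local Redis 6.0+ server on the default port and uses only the existing Connection API.

    from redis.connection import Connection

    conn = Connection(host="localhost", port=6379)
    conn.connect()
    # HELLO 2 keeps the connection on RESP2 and just reports the handshake
    # fields; over RESP2 the reply is a flat [field, value, ...] list.
    conn.send_command("HELLO", 2)
    reply = conn.read_response()
    print(reply[reply.index(b"proto") + 1])   # 2
    conn.disconnect()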
@@ -126,6 +127,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 + self.resp = resp def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -288,6 +290,11 @@ def on_connect(self): if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + if self.resp: + self.send_command("HELLO", self.resp) + self.read_response() + # if a client_name is given, set it if self.client_name: self.send_command("CLIENT", "SETNAME", self.client_name) From 124af35bbc520aa11e203055983a8a208562d0d9 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Sun, 26 Feb 2023 05:59:46 +0200 Subject: [PATCH 06/21] core commands --- redis/client.py | 48 ++++ redis/connection.py | 12 +- tests/conftest.py | 7 +- tests/test_commands.py | 622 +++++++++++++++++++++++++++++------------ 4 files changed, 500 insertions(+), 189 deletions(-) diff --git a/redis/client.py b/redis/client.py index 1a9b96b83d..55ae33fc64 100755 --- a/redis/client.py +++ b/redis/client.py @@ -340,6 +340,12 @@ def parse_xread(response): return [[r[0], parse_stream_list(r[1])] for r in response] +def parse_xread_resp3(response): + if response is None: + return {} + return {key: [parse_stream_list(value)] for key, value in response.items()} + + def parse_xpending(response, **options): if options.get("parse_detail", False): return parse_xpending_range(response) @@ -841,6 +847,43 @@ class AbstractRedis: "ZMSCORE": parse_zmscore, } + RESP3_RESPONSE_CALLBACKS = { + **string_keys_to_dict( + "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " + "ZUNION HGETALL XREADGROUP", + lambda r, **kwargs: r, + ), + "CONFIG GET": lambda r: { + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None + for key, value in r.items() + }, + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} + for x in r + ] + if isinstance(r, list) + else bool_ok(r), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), + "XINFO CONSUMERS": lambda r: [ + {str_if_bytes(key): value for key, value in x.items()} for x in r + ], + "XINFO STREAM": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + }, + "MEMORY STATS": lambda r: { + str_if_bytes(key): value for key, value in r.items() + }, + } + class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): """ @@ -942,6 +985,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Redis client. 
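Together with the connection-level HELLO above, this protocol argument is the whole RESP3 opt-in surface at this point in the series. A rough usage sketch based on the expectations in the updated tests below; it assumes a local Redis 6+ server, the pure-Python parsers (the RESP3 parser arrives in a later commit of this series, and hiredis is not wired up for it here), and protocol passed as the string "3", since the callback check below compares connection_kwargs.get("protocol") against "3".

    import redis

    r2 = redis.Redis(host="localhost", port=6379)                 # RESP2 (default)
    r3 = redis.Redis(host="localhost", port=6379, protocol="3")   # HELLO 3 on connect

    r3.zadd("z", {"a": 1.0, "b": 2.0})
    # RESP2 keeps the classic list-of-tuples shape; with RESP3 the identity
    # callbacks above pass the [member, score] pairs through as lists.
    print(r2.zrange("z", 0, -1, withscores=True))   # [(b'a', 1.0), (b'b', 2.0)]
    print(r3.zrange("z", 0, -1, withscores=True))   # [[b'a', 1.0], [b'b', 2.0]]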
@@ -990,6 +1034,7 @@ def __init__( "client_name": client_name, "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, + "protocol": protocol, } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -1037,6 +1082,9 @@ def __init__( self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + if self.connection_pool.connection_kwargs.get("protocol") == "3": + self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" diff --git a/redis/connection.py b/redis/connection.py index ae0b051573..0fe9389f72 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -72,7 +72,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, - resp: Optional[int] = None, + protocol: Optional[int] = 2, ): """ Initialize a new Connection. @@ -127,7 +127,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 - self.resp = resp + self.protocol = protocol def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -291,9 +291,11 @@ def on_connect(self): raise AuthenticationError("Invalid Username or Password") # if resp version is specified, switch to it - if self.resp: - self.send_command("HELLO", self.resp) - self.read_response() + if self.protocol != 2: + self.send_command("HELLO", self.protocol) + response = self.read_response() + if response[b"proto"] != int(self.protocol): + raise ConnectionError("Invalid RESP version") # if a client_name is given, set it if self.client_name: diff --git a/tests/conftest.py b/tests/conftest.py index 27dcc741a7..557059230d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ from redis.retry import Retry REDIS_INFO = {} -default_redis_url = "redis://localhost:6379/9" +default_redis_url = "redis://localhost:6379/0" default_redismod_url = "redis://localhost:36379" default_redis_unstable_url = "redis://localhost:6378" @@ -472,3 +472,8 @@ def wait_for_command(client, monitor, command, key=None): return monitor_response if key in monitor_response["command"]: return None + + +def is_resp2_connection(r): + protocol = r.connection_pool.connection_kwargs.get("protocol") + return protocol == "2" or protocol is None diff --git a/tests/test_commands.py b/tests/test_commands.py index 94249e9419..1af69c83c0 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -13,6 +13,7 @@ from .conftest import ( _get_client, + is_resp2_connection, skip_if_redis_enterprise, skip_if_server_version_gte, skip_if_server_version_lt, @@ -380,7 +381,10 @@ def teardown(): assert len(r.acl_log()) == 2 assert len(r.acl_log(count=1)) == 1 assert isinstance(r.acl_log()[0], dict) - assert "client-info" in r.acl_log(count=1)[0] + if is_resp2_connection(r): + assert "client-info" in r.acl_log(count=1)[0] + else: + assert "client-info" in r.acl_log(count=1)[0].keys() assert r.acl_log_reset() @skip_if_server_version_lt("6.0.0") @@ -1535,7 +1539,10 @@ def test_hrandfield(self, r): assert r.hrandfield("key") is not None assert len(r.hrandfield("key", 2)) == 2 # with values - assert len(r.hrandfield("key", 2, True)) == 4 + if is_resp2_connection(r): + assert len(r.hrandfield("key", 2, True)) == 4 + else: + assert len(r.hrandfield("key", 2, True)) == 2 # without duplications assert len(r.hrandfield("key", 10)) == 5 # with duplications @@ -1688,17 +1695,30 @@ def 
test_stralgo_lcs(self, r): assert r.stralgo("LCS", key1, key2, specific_argument="keys") == res # test other labels assert r.stralgo("LCS", value1, value2, len=True) == len(res) - assert r.stralgo("LCS", value1, value2, idx=True) == { - "len": len(res), - "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], - } - assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { - "len": len(res), - "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], - } - assert r.stralgo( - "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True - ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + if is_resp2_connection(r): + assert r.stralgo("LCS", value1, value2, idx=True) == { + "len": len(res), + "matches": [[(4, 7), (5, 8)], [(2, 3), (0, 1)]], + } + assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { + "len": len(res), + "matches": [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]], + } + assert r.stralgo( + "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True + ) == {"len": len(res), "matches": [[4, (4, 7), (5, 8)]]} + else: + assert r.stralgo("LCS", value1, value2, idx=True) == { + "len": len(res), + "matches": [[[4, 7], [5, 8]], [[2, 3], [0, 1]]], + } + assert r.stralgo("LCS", value1, value2, idx=True, withmatchlen=True) == { + "len": len(res), + "matches": [[[4, 7], [5, 8], 4], [[2, 3], [0, 1], 2]], + } + assert r.stralgo( + "LCS", value1, value2, idx=True, minmatchlen=4, withmatchlen=True + ) == {"len": len(res), "matches": [[[4, 7], [5, 8], 4]]} @skip_if_server_version_lt("6.0.0") @skip_if_server_version_gte("7.0.0") @@ -2147,8 +2167,10 @@ def test_spop_multi_value(self, r): for value in values: assert value in s - - assert r.spop("a", 1) == list(set(s) - set(values)) + if is_resp2_connection(r): + assert r.spop("a", 1) == list(set(s) - set(values)) + else: + assert r.spop("a", 1) == set(s) - set(values) def test_srandmember(self, r): s = [b"1", b"2", b"3"] @@ -2199,11 +2221,18 @@ def test_script_debug(self, r): def test_zadd(self, r): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} r.zadd("a", mapping) - assert r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [ + (b"a1", 1.0), + (b"a2", 2.0), + (b"a3", 3.0), + ] + else: + assert r.zrange("a", 0, -1, withscores=True) == [ + [b"a1", 1.0], + [b"a2", 2.0], + [b"a3", 3.0], + ] # error cases with pytest.raises(exceptions.DataError): @@ -2220,17 +2249,32 @@ def test_zadd(self, r): def test_zadd_nx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + else: + assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] def test_zadd_xx(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + else: + assert r.zrange("a", 0, -1, withscores=True) == [[b"a1", 99.0]] def test_zadd_ch(self, r): assert r.zadd("a", {"a1": 1}) == 1 assert r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert r.zrange("a", 0, -1, withscores=True) == [(b"a2", 2.0), (b"a1", 99.0)] + if is_resp2_connection(r): + assert r.zrange("a", 0, -1, withscores=True) == [ + 
(b"a2", 2.0), + (b"a1", 99.0), + ] + else: + assert r.zrange("a", 0, -1, withscores=True) == [ + [b"a2", 2.0], + [b"a1", 99.0], + ] def test_zadd_incr(self, r): assert r.zadd("a", {"a1": 1}) == 1 @@ -2278,7 +2322,10 @@ def test_zdiff(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiff(["a", "b"]) == [b"a3"] - assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] + if is_resp2_connection(r): + assert r.zdiff(["a", "b"], withscores=True) == [b"a3", b"3"] + else: + assert r.zdiff(["a", "b"], withscores=True) == [[b"a3", 3.0]] @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.2.0") @@ -2287,7 +2334,10 @@ def test_zdiffstore(self, r): r.zadd("b", {"a1": 1, "a2": 2}) assert r.zdiffstore("out", ["a", "b"]) assert r.zrange("out", 0, -1) == [b"a3"] - assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] + if is_resp2_connection(r): + assert r.zrange("out", 0, -1, withscores=True) == [(b"a3", 3.0)] + else: + assert r.zrange("out", 0, -1, withscores=True) == [[b"a3", 3.0]] def test_zincrby(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2312,23 +2362,48 @@ def test_zinter(self, r): # invalid aggregation with pytest.raises(exceptions.DataError): r.zinter(["a", "b", "c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["a", "b", "c"], withscores=True) == [(b"a3", 8), (b"a1", 9)] - # aggregate with MAX - assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a3", 5), - (b"a1", 6), - ] - # aggregate with MIN - assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a3", 1), - ] - # with weights - assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + if is_resp2_connection(r): + # aggregate with SUM + assert r.zinter(["a", "b", "c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + # aggregate with MAX + assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + (b"a3", 5), + (b"a1", 6), + ] + # aggregate with MIN + assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + (b"a1", 1), + (b"a3", 1), + ] + # with weights + assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] + else: + # aggregate with SUM + assert r.zinter(["a", "b", "c"], withscores=True) == [ + [b"a3", 8], + [b"a1", 9], + ] + # aggregate with MAX + assert r.zinter(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + [b"a3", 5], + [b"a1", 6], + ] + # aggregate with MIN + assert r.zinter(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + [b"a1", 1], + [b"a3", 1], + ] + # with weights + assert r.zinter({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + [b"a3", 20], + [b"a1", 23], + ] @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -2345,7 +2420,10 @@ def test_zinterstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"]) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 8], [b"a1", 9]] @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): @@ -2353,7 +2431,10 @@ def test_zinterstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], 
aggregate="MAX") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 5], [b"a1", 6]] @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): @@ -2361,7 +2442,10 @@ def test_zinterstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a1", 1], [b"a3", 3]] @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): @@ -2369,23 +2453,34 @@ def test_zinterstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 - assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + else: + assert r.zrange("d", 0, -1, withscores=True) == [[b"a3", 20], [b"a1", 23]] @skip_if_server_version_lt("4.9.0") def test_zpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmax("a") == [(b"a3", 3)] - - # with count - assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + if is_resp2_connection(r): + assert r.zpopmax("a") == [(b"a3", 3)] + # with count + assert r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + else: + assert r.zpopmax("a") == [b"a3", 3.0] + # with count + assert r.zpopmax("a", count=2) == [[b"a2", 2], [b"a1", 1]] @skip_if_server_version_lt("4.9.0") def test_zpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert r.zpopmin("a") == [(b"a1", 1)] - - # with count - assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + if is_resp2_connection(r): + assert r.zpopmin("a") == [(b"a1", 1)] + # with count + assert r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + else: + assert r.zpopmin("a") == [b"a1", 1.0] + # with count + assert r.zpopmin("a", count=2) == [[b"a2", 2], [b"a3", 3]] @skip_if_server_version_lt("6.2.0") def test_zrandemember(self, r): @@ -2393,7 +2488,10 @@ def test_zrandemember(self, r): assert r.zrandmember("a") is not None assert len(r.zrandmember("a", 2)) == 2 # with scores - assert len(r.zrandmember("a", 2, True)) == 4 + if is_resp2_connection(r): + assert len(r.zrandmember("a", 2, True)) == 4 + else: + assert len(r.zrandmember("a", 2, True)) == 2 # without duplications assert len(r.zrandmember("a", 10)) == 5 # with duplications @@ -2457,14 +2555,18 @@ def test_zrange(self, r): assert r.zrange("a", 0, 2, desc=True) == [b"a3", b"a2", b"a1"] # withscores - assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] - assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] - - # custom score function - assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + if is_resp2_connection(r): + assert r.zrange("a", 0, 1, withscores=True) == [(b"a1", 1.0), (b"a2", 2.0)] + assert r.zrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a3", 3.0)] + + # custom score function + assert r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a1", 1), + (b"a2", 2), + ] + else: 
+ assert r.zrange("a", 0, 1, withscores=True) == [[b"a1", 1.0], [b"a2", 2.0]] + assert r.zrange("a", 1, 2, withscores=True) == [[b"a2", 2.0], [b"a3", 3.0]] def test_zrange_errors(self, r): with pytest.raises(exceptions.DataError): @@ -2496,14 +2598,25 @@ def test_zrange_params(self, r): b"a3", b"a2", ] - assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrange( - "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + if is_resp2_connection(r): + assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + assert r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + + else: + assert r.zrange("a", 2, 4, byscore=True, withscores=True) == [ + [b"a2", 2.0], + [b"a3", 3.0], + [b"a4", 4.0], + ] + assert r.zrange( + "a", 4, 2, desc=True, byscore=True, withscores=True, score_cast_func=int + ) == [[b"a4", 4], [b"a3", 3], [b"a2", 2]] # rev assert r.zrange("a", 0, 1, desc=True) == [b"a5", b"a4"] @@ -2516,7 +2629,10 @@ def test_zrangestore(self, r): assert r.zrange("b", 0, -1) == [b"a1", b"a2"] assert r.zrangestore("b", "a", 1, 2) assert r.zrange("b", 0, -1) == [b"a2", b"a3"] - assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + if is_resp2_connection(r): + assert r.zrange("b", 0, -1, withscores=True) == [(b"a2", 2), (b"a3", 3)] + else: + assert r.zrange("b", 0, -1, withscores=True) == [[b"a2", 2], [b"a3", 3]] # reversed order assert r.zrangestore("b", "a", 1, 2, desc=True) assert r.zrange("b", 0, -1) == [b"a1", b"a2"] @@ -2551,16 +2667,28 @@ def test_zrangebyscore(self, r): # slicing with start/num assert r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] - assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + if is_resp2_connection(r): + assert r.zrangebyscore("a", 2, 4, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + else: + assert r.zrangebyscore("a", 2, 4, withscores=True) == [ + [b"a2", 2.0], + [b"a3", 3.0], + [b"a4", 4.0], + ] + assert r.zrangebyscore("a", 2, 4, withscores=True, score_cast_func=int) == [ + [b"a2", 2], + [b"a3", 3], + [b"a4", 4], + ] def test_zrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2607,33 +2735,61 @@ def test_zrevrange(self, r): assert r.zrevrange("a", 0, 1) == [b"a3", b"a2"] assert r.zrevrange("a", 1, 2) == [b"a2", b"a1"] - # withscores - assert r.zrevrange("a", 0, 1, withscores=True) == [(b"a3", 3.0), (b"a2", 2.0)] - assert r.zrevrange("a", 1, 2, withscores=True) == [(b"a2", 2.0), (b"a1", 1.0)] + if is_resp2_connection(r): + # withscores + assert r.zrevrange("a", 0, 1, withscores=True) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] + assert r.zrevrange("a", 1, 2, withscores=True) == [ + (b"a2", 2.0), + (b"a1", 1.0), + ] - # custom score function - assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + # custom score function + assert r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] + else: + # 
withscores + assert r.zrevrange("a", 0, 1, withscores=True) == [ + [b"a3", 3.0], + [b"a2", 2.0], + ] + assert r.zrevrange("a", 1, 2, withscores=True) == [ + [b"a2", 2.0], + [b"a1", 1.0], + ] def test_zrevrangebyscore(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) assert r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] # slicing with start/num assert r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] - # withscores - assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] - # custom score function - assert r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) == [ - (b"a4", 4), - (b"a3", 3), - (b"a2", 2), - ] + + if is_resp2_connection(r): + # withscores + assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ + (b"a4", 4.0), + (b"a3", 3.0), + (b"a2", 2.0), + ] + # custom score function + assert r.zrevrangebyscore( + "a", 4, 2, withscores=True, score_cast_func=int + ) == [ + (b"a4", 4), + (b"a3", 3), + (b"a2", 2), + ] + else: + # withscores + assert r.zrevrangebyscore("a", 4, 2, withscores=True) == [ + [b"a4", 4.0], + [b"a3", 3.0], + [b"a2", 2.0], + ] def test_zrevrank(self, r): r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -2655,33 +2811,63 @@ def test_zunion(self, r): r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) # sum assert r.zunion(["a", "b", "c"]) == [b"a2", b"a4", b"a3", b"a1"] - assert r.zunion(["a", "b", "c"], withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] - # max - assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] - # min - assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ - (b"a1", 1), - (b"a2", 1), - (b"a3", 1), - (b"a4", 4), - ] - # with weight - assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + + if is_resp2_connection(r): + assert r.zunion(["a", "b", "c"], withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + # max + assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + # min + assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + (b"a1", 1), + (b"a2", 1), + (b"a3", 1), + (b"a4", 4), + ] + # with weight + assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + else: + assert r.zunion(["a", "b", "c"], withscores=True) == [ + [b"a2", 3], + [b"a4", 4], + [b"a3", 8], + [b"a1", 9], + ] + # max + assert r.zunion(["a", "b", "c"], aggregate="MAX", withscores=True) == [ + [b"a2", 2], + [b"a4", 4], + [b"a3", 5], + [b"a1", 6], + ] + # min + assert r.zunion(["a", "b", "c"], aggregate="MIN", withscores=True) == [ + [b"a1", 1], + [b"a2", 1], + [b"a3", 1], + [b"a4", 4], + ] + # with weight + assert r.zunion({"a": 1, "b": 2, "c": 3}, withscores=True) == [ + [b"a2", 5], + [b"a4", 12], + [b"a3", 20], + [b"a1", 23], + ] @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): @@ -2689,12 +2875,21 @@ def test_zunionstore_sum(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"]) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ 
+ (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 3], + [b"a4", 4], + [b"a3", 8], + [b"a1", 9], + ] @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): @@ -2702,12 +2897,20 @@ def test_zunionstore_max(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 2], + [b"a4", 4], + [b"a3", 5], + [b"a1", 6], + ] @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): @@ -2715,12 +2918,20 @@ def test_zunionstore_min(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a1", 1], + [b"a2", 2], + [b"a3", 3], + [b"a4", 4], + ] @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): @@ -2728,12 +2939,20 @@ def test_zunionstore_with_weight(self, r): r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + if is_resp2_connection(r): + assert r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + else: + assert r.zrange("d", 0, -1, withscores=True) == [ + [b"a2", 5], + [b"a4", 12], + [b"a3", 20], + [b"a1", 23], + ] @skip_if_server_version_lt("6.1.240") def test_zmscore(self, r): @@ -4108,7 +4327,10 @@ def test_xinfo_stream_full(self, r): info = r.xinfo_stream(stream, full=True) assert info["length"] == 1 - assert m1 in info["entries"] + if is_resp2_connection(r): + assert m1 in info["entries"] + else: + assert m1 in info["entries"][0] assert len(info["groups"]) == 1 @skip_if_server_version_lt("5.0.0") @@ -4249,25 +4471,40 @@ def test_xread(self, r): m1 = r.xadd(stream, {"foo": "bar"}) m2 = r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + strem_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert r.xread(streams={stream: 0}) == expected + res = r.xread(streams={stream: 0}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert r.xread(streams={stream: 0}, count=1) == expected + res = r.xread(streams={stream: 0}, count=1) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == 
{strem_name: [expected_entries]} - expected = [[stream.encode(), [get_stream_message(r, stream, m2)]]] + expected_entries = [get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert r.xread(streams={stream: m1}) == expected + res = r.xread(streams={stream: m1}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} # xread starting at the last message returns an empty list - assert r.xread(streams={stream: m2}) == [] + res = r.xread(streams={stream: m2}) + if is_resp2_connection(r): + assert res == [] + else: + assert res == {} @skip_if_server_version_lt("5.0.0") def test_xreadgroup(self, r): @@ -4278,21 +4515,30 @@ def test_xreadgroup(self, r): m2 = r.xadd(stream, {"bing": "baz"}) r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [get_stream_message(r, stream, m1), get_stream_message(r, stream, m2)], - ] + strem_name = stream.encode() + expected_entries = [ + get_stream_message(r, stream, m1), + get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = r.xreadgroup(group, consumer, streams={stream: ">"}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [get_stream_message(r, stream, m1)]]] + expected_entries = [get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) == expected + res = r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} r.xgroup_destroy(stream, group) @@ -4300,27 +4546,37 @@ def test_xreadgroup(self, r): # will only find messages added after this r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + if is_resp2_connection(r): + assert r.xreadgroup(group, consumer, streams={stream: ">"}) == [] + else: + assert r.xreadgroup(group, consumer, streams={stream: ">"}) == {} # xreadgroup with noack does not have any items in the PEL r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") - assert ( - len(r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True)[0][1]) - == 2 - ) - # now there should be nothing pending - assert len(r.xreadgroup(group, consumer, streams={stream: "0"})[0][1]) == 0 + res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert len(res[0][1]) == 2 + # now there should be nothing pending + assert len(empty_res[0][1]) == 0 + else: + assert len(res[strem_name][0]) == 2 + # now there should be nothing pending + assert len(empty_res[strem_name][0]) == 0 r.xgroup_destroy(stream, group) r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] r.xreadgroup(group, consumer, streams={stream: ">"}) r.xtrim(stream, 0) - assert r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + res = 
r.xreadgroup(group, consumer, streams={stream: "0"}) + if is_resp2_connection(r): + assert res == [[strem_name, expected_entries]] + else: + assert res == {strem_name: [expected_entries]} @skip_if_server_version_lt("5.0.0") def test_xrevrange(self, r): From 6085a498bd872c793f834bf647b6307f31c3eb7d Mon Sep 17 00:00:00 2001 From: dvora-h Date: Mon, 27 Feb 2023 11:01:35 +0200 Subject: [PATCH 07/21] python resp3 parser --- redis/cluster.py | 1 + redis/connection.py | 7 +++- redis/parsers/__init__.py | 2 + redis/parsers/resp3.py | 85 +++++++++++++++++++++++++++++++++++++++ tests/conftest.py | 5 ++- 5 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 redis/parsers/resp3.py diff --git a/redis/cluster.py b/redis/cluster.py index f8896372c5..56f875eee6 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -138,6 +138,7 @@ def parse_cluster_shards(resp, **options): "port", "retry", "retry_on_timeout", + "protocol", "socket_connect_timeout", "socket_keepalive", "socket_keepalive_options", diff --git a/redis/connection.py b/redis/connection.py index 0fe9389f72..cfa8a9d36c 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -22,7 +22,7 @@ ResponseError, TimeoutError, ) -from .parsers import Encoder, _HiredisParser, _RESP2Parser +from .parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -38,7 +38,7 @@ SENTINEL = object() -DefaultParser: Type[Union[_RESP2Parser, _HiredisParser]] +DefaultParser: Type[Union[_RESP2Parser, _RESP3Parser, _HiredisParser]] if HIREDIS_AVAILABLE: DefaultParser = _HiredisParser else: @@ -292,6 +292,9 @@ def on_connect(self): # if resp version is specified, switch to it if self.protocol != 2: + if isinstance(self._parser, _RESP2Parser): + self.set_parser(_RESP3Parser) + self._parser.on_connect(self) self.send_command("HELLO", self.protocol) response = self.read_response() if response[b"proto"] != int(self.protocol): diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py index 68b32ed5ba..eae1879554 100644 --- a/redis/parsers/__init__.py +++ b/redis/parsers/__init__.py @@ -3,6 +3,7 @@ from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser from .resp2 import _AsyncRESP2Parser, _RESP2Parser +from .resp3 import _RESP3Parser __all__ = [ "AsyncCommandsParser", @@ -13,4 +14,5 @@ "BaseParser", "_HiredisParser", "_RESP2Parser", + "_RESP3Parser", ] diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py new file mode 100644 index 0000000000..013c3e9068 --- /dev/null +++ b/redis/parsers/resp3.py @@ -0,0 +1,85 @@ +from ..exceptions import ConnectionError, InvalidResponse +from .base import _RESPBase +from .socket import SERVER_CLOSED_CONNECTION_ERROR + + +class _RESP3Parser(_RESPBase): + """RESP3 protocol implementation""" + + def read_response(self, disable_decoding=False): + pos = self._buffer.get_pos() + try: + result = self._read_response(disable_decoding=disable_decoding) + except BaseException: + self._buffer.rewind(pos) + raise + else: + self._buffer.purge() + return result + + def _read_response(self, disable_decoding=False): + raw = self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + byte, response = raw[:1], raw[1:] + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = self._buffer.read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise 
immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response and verbatim strings + elif byte in (b"$", b"="): + response = self._buffer.read(int(response)) + # array response + elif byte == b"*": + response = [ + self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + ] + # set response + elif byte == b"~": + response = { + self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + } + # map response + elif byte == b"%": + response = { + self._read_response( + disable_decoding=disable_decoding + ): self._read_response(disable_decoding=disable_decoding) + for i in range(int(response)) + } + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if disable_decoding is False: + response = self.encoder.decode(response) + return response diff --git a/tests/conftest.py b/tests/conftest.py index 557059230d..035dbc85cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -475,5 +475,8 @@ def wait_for_command(client, monitor, command, key=None): def is_resp2_connection(r): - protocol = r.connection_pool.connection_kwargs.get("protocol") + if isinstance(r, redis.Redis): + protocol = r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.RedisCluster): + protocol = r.nodes_manager.connection_kwargs.get("protocol") return protocol == "2" or protocol is None From b1caaee1082e257c66b0f8fea0ef879e0eb70553 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 28 Feb 2023 01:18:25 +0200 Subject: [PATCH 08/21] pipeline --- redis/connection.py | 5 +++-- tests/test_connection.py | 8 ++++---- tests/test_pipeline.py | 2 -- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index cfa8a9d36c..03fdf6ca6d 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -285,7 +285,7 @@ def on_connect(self): # arg. retry auth with just the password. 
# https://github.com/andymccurdy/redis-py/issues/1274 self.send_command("AUTH", auth_args[-1], check_health=False) - auth_response = self.read_response() + auth_response = self.read_response() if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") @@ -297,7 +297,8 @@ def on_connect(self): self._parser.on_connect(self) self.send_command("HELLO", self.protocol) response = self.read_response() - if response[b"proto"] != int(self.protocol): + if (response.get(b"proto") != int(self.protocol) and + response.get("proto") != int(self.protocol)): raise ConnectionError("Invalid RESP version") # if a client_name is given, set it diff --git a/tests/test_connection.py b/tests/test_connection.py index b96c076375..e165cb1ba8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,7 +9,7 @@ from redis.backoff import NoBackoff from redis.connection import Connection from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.parsers import _HiredisParser, _RESP2Parser +from redis.parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE @@ -130,8 +130,8 @@ def test_connect_timeout_error_without_retry(self): @pytest.mark.onlynoncluster @pytest.mark.parametrize( "parser_class", - [_RESP2Parser, _HiredisParser], - ids=["PythonParser", "HiredisParser"], + [_RESP2Parser, _RESP3Parser, _HiredisParser], + ids=["RESP2Parser", "RESP3Parser", "HiredisParser"], ) def test_connection_parse_response_resume(r: redis.Redis, parser_class): """ @@ -151,7 +151,7 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class): ) mock_socket = MockSocket(message, interrupt_every=2) - if isinstance(conn._parser, _RESP2Parser): + if isinstance(conn._parser, _RESP2Parser) or isinstance(conn._parser, _RESP3Parser): conn._parser._buffer._sock = mock_socket else: conn._parser._sock = mock_socket diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 716cd0fbf6..7b98ece692 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -19,7 +19,6 @@ def test_pipeline(self, r): .zadd("z", {"z1": 1}) .zadd("z", {"z2": 4}) .zincrby("z", 1, "z1") - .zrange("z", 0, 5, withscores=True) ) assert pipe.execute() == [ True, @@ -27,7 +26,6 @@ def test_pipeline(self, r): True, True, 2.0, - [(b"z1", 2.0), (b"z2", 4)], ] def test_pipeline_memoryview(self, r): From be9912fbff42c185bef4436aa12c9717a54ec8ef Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 14 Mar 2023 00:40:00 +0200 Subject: [PATCH 09/21] async resp3 parser --- redis/asyncio/client.py | 3 + redis/asyncio/connection.py | 24 +++++++- redis/connection.py | 2 +- redis/parsers/__init__.py | 3 +- redis/parsers/resp3.py | 101 +++++++++++++++++++++++++++++++-- tests/test_asyncio/conftest.py | 21 +++++++ 6 files changed, 144 insertions(+), 10 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3fc7fad83e..e5c9617a60 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -253,6 +253,9 @@ def __init__( self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + if self.connection_pool.connection_kwargs.get("protocol") == "3": + self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + # If using a single connection client, we need to lock creation-of and use-of # the client in order to avoid race conditions such as using asyncio.gather # on a set of redis commands diff --git a/redis/asyncio/connection.py 
b/redis/asyncio/connection.py index 2cc2ee7904..70f02389f3 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -43,7 +43,13 @@ from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -from ..parsers import BaseParser, Encoder, _AsyncHiredisParser, _AsyncRESP2Parser +from ..parsers import ( + BaseParser, + Encoder, + _AsyncHiredisParser, + _AsyncRESP2Parser, + _AsyncRESP3Parser, +) SYM_STAR = b"*" SYM_DOLLAR = b"$" @@ -59,7 +65,7 @@ class _Sentinel(enum.Enum): SENTINEL = _Sentinel.sentinel -DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncHiredisParser]] +DefaultParser: Type[Union[_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser]] if HIREDIS_AVAILABLE: DefaultParser = _AsyncHiredisParser else: @@ -104,6 +110,7 @@ class Connection: "last_active_at", "encoder", "ssl_context", + "protocol", "_reader", "_writer", "_parser", @@ -140,6 +147,7 @@ def __init__( redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, ): if (username or password) and credential_provider is not None: raise DataError( @@ -190,6 +198,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 + self.protocol = protocol def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) @@ -344,6 +353,17 @@ async def on_connect(self) -> None: if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") + # if resp version is specified, switch to it + if self.protocol != 2: + if isinstance(self._parser, _AsyncRESP2Parser): + self.set_parser(_AsyncRESP3Parser) + self._parser.on_connect(self) + await self.send_command("HELLO", self.protocol) + response = await self.read_response() + if (response.get(b"proto") != int(self.protocol) and + response.get("proto") != int(self.protocol)): + raise ConnectionError("Invalid RESP version") + # if a client_name is given, set it if self.client_name: await self.send_command("CLIENT", "SETNAME", self.client_name) diff --git a/redis/connection.py b/redis/connection.py index 03fdf6ca6d..83008597e1 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -285,7 +285,7 @@ def on_connect(self): # arg. retry auth with just the password. 
# https://github.com/andymccurdy/redis-py/issues/1274 self.send_command("AUTH", auth_args[-1], check_health=False) - auth_response = self.read_response() + auth_response = self.read_response() if str_if_bytes(auth_response) != "OK": raise AuthenticationError("Invalid Username or Password") diff --git a/redis/parsers/__init__.py b/redis/parsers/__init__.py index eae1879554..0586016a61 100644 --- a/redis/parsers/__init__.py +++ b/redis/parsers/__init__.py @@ -3,12 +3,13 @@ from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser from .resp2 import _AsyncRESP2Parser, _RESP2Parser -from .resp3 import _RESP3Parser +from .resp3 import _AsyncRESP3Parser, _RESP3Parser __all__ = [ "AsyncCommandsParser", "_AsyncHiredisParser", "_AsyncRESP2Parser", + "_AsyncRESP3Parser", "CommandsParser", "Encoder", "BaseParser", diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index 013c3e9068..d79a69e687 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -1,5 +1,8 @@ -from ..exceptions import ConnectionError, InvalidResponse -from .base import _RESPBase +from typing import Any, Union + +from ..exceptions import ConnectionError, InvalidResponse, ResponseError +from ..typing import EncodableT +from .base import _AsyncRESPBase, _RESPBase from .socket import SERVER_CLOSED_CONNECTION_ERROR @@ -61,13 +64,13 @@ def _read_response(self, disable_decoding=False): elif byte == b"*": response = [ self._read_response(disable_decoding=disable_decoding) - for i in range(int(response)) + for _ in range(int(response)) ] # set response elif byte == b"~": response = { self._read_response(disable_decoding=disable_decoding) - for i in range(int(response)) + for _ in range(int(response)) } # map response elif byte == b"%": @@ -75,11 +78,97 @@ def _read_response(self, disable_decoding=False): self._read_response( disable_decoding=disable_decoding ): self._read_response(disable_decoding=disable_decoding) - for i in range(int(response)) + for _ in range(int(response)) + } + else: + raise InvalidResponse(f"Protocol Error: {raw!r}") + + if isinstance(response, bytes) and disable_decoding is False: + response = self.encoder.decode(response) + return response + + +class _AsyncRESP3Parser(_AsyncRESPBase): + async def read_response(self, disable_decoding: bool = False): + if self._chunks: + # augment parsing buffer with previously read data + self._buffer += b"".join(self._chunks) + self._chunks.clear() + self._pos = 0 + response = await self._read_response(disable_decoding=disable_decoding) + # Successfully parsing a response allows us to clear our parsing buffer + self._clear() + return response + + async def _read_response( + self, disable_decoding: bool = False + ) -> Union[EncodableT, ResponseError, None]: + if not self._stream or not self.encoder: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._readline() + response: Any + byte, response = raw[:1], raw[1:] + + # if byte not in (b"-", b"+", b":", b"$", b"*"): + # raise InvalidResponse(f"Protocol Error: {raw!r}") + + # server returned an error + if byte in (b"-", b"!"): + if byte == b"!": + response = await self._read(int(response)) + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + self._clear() # Successful parse + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. 
the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # null value + elif byte == b"_": + return None + # int and big int values + elif byte in (b":", b"("): + return int(response) + # double value + elif byte == b",": + return float(response) + # bool value + elif byte == b"#": + return response == b"t" + # bulk response and verbatim strings + elif byte in (b"$", b"="): + response = await self._read(int(response)) + # array response + elif byte == b"*": + response = [ + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + ] + # set response + elif byte == b"~": + response = { + (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) + } + # map response + elif byte == b"%": + response = { + (await self._read_response( + disable_decoding=disable_decoding + )): (await self._read_response(disable_decoding=disable_decoding)) + for _ in range(int(response)) } else: raise InvalidResponse(f"Protocol Error: {raw!r}") - if disable_decoding is False: + if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 28a6f0626f..1fb5cf0651 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -235,6 +235,27 @@ async def wait_for_command( if key in monitor_response["command"]: return None +def get_protocol_version(r): + if isinstance(r, redis.Redis): + return r.connection_pool.connection_kwargs.get("protocol") + elif isinstance(r, redis.RedisCluster): + return r.nodes_manager.connection_kwargs.get("protocol") + +def assert_resp_response(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol == "2" or protocol is None: + assert response == resp2_expected + else: + assert response == resp3_expected + + +def assert_resp_response_in(r, response, resp2_expected, resp3_expected): + protocol = get_protocol_version(r) + if protocol == "2" or protocol is None: + assert response in resp2_expected + else: + assert response in resp3_expected + # python 3.6 doesn't have the asynccontextmanager decorator. Provide it here. 
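A short usage note on the helpers just added to tests/test_asyncio/conftest.py: get_protocol_version inspects the client's connection kwargs, and assert_resp_response lets a single test body carry both reply shapes, since RESP2 returns sorted-set scores as a flat list of (member, score) tuples while RESP3 returns [member, score] pairs. The sketch below is illustrative only and not part of the patch; the fixture name r and the zpopmin values are assumptions borrowed from the surrounding tests.

from .conftest import assert_resp_response  # helper defined in the hunk above

async def test_zpopmin_reply_shape(r):
    # "r" is assumed to be the usual async client fixture from this conftest
    await r.zadd("z", {"a1": 1, "a2": 2})
    response = await r.zpopmin("z")
    # RESP2 reply shape: [(b"a1", 1.0)]   RESP3 reply shape: [b"a1", 1.0]
    assert_resp_response(r, response, [(b"a1", 1.0)], [b"a1", 1.0])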
class AsyncContextManager: From 97c4fedaa6334eb1b6aafee4ae26373570be3ea1 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 14 Mar 2023 10:44:05 +0200 Subject: [PATCH 10/21] some asymc tests --- tests/test_asyncio/test_commands.py | 84 ++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 20 deletions(-) diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 7c6fd45ab9..132bbc2f90 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -18,6 +18,7 @@ skip_unless_arch_bits, ) +from .conftest import assert_resp_response, assert_resp_response_in REDIS_6_VERSION = "5.9.0" @@ -264,7 +265,8 @@ async def test_acl_log(self, r_teardown, create_redis): assert len(await r.acl_log()) == 2 assert len(await r.acl_log(count=1)) == 1 assert isinstance((await r.acl_log())[0], dict) - assert "client-info" in (await r.acl_log(count=1))[0] + expected = (await r.acl_log(count=1))[0] + assert_resp_response_in(r, "client-info", expected, expected.keys()) assert await r.acl_log_reset() @skip_if_server_version_lt(REDIS_6_VERSION) @@ -915,6 +917,19 @@ async def test_pttl_no_key(self, r: redis.Redis): """PTTL on servers 2.8 and after return -2 when the key doesn't exist""" assert await r.pttl("a") == -2 + @skip_if_server_version_lt("6.2.0") + async def test_hrandfield(self, r): + assert await r.hrandfield("key") is None + await r.hset("key", mapping={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}) + assert await r.hrandfield("key") is not None + assert len(await r.hrandfield("key", 2)) == 2 + # with values + assert_resp_response(r, len(await r.hrandfield("key", 2, True)), 4, 2) + # without duplications + assert len(await r.hrandfield("key", 10)) == 5 + # with duplications + assert len(await r.hrandfield("key", -10)) == 10 + @pytest.mark.onlynoncluster async def test_randomkey(self, r: redis.Redis): assert await r.randomkey() is None @@ -1374,7 +1389,10 @@ async def test_spop_multi_value(self, r: redis.Redis): for value in values: assert value in s - assert await r.spop("a", 1) == list(set(s) - set(values)) + response = await r.spop("a", 1) + assert_resp_response( + r, response, list(set(s) - set(values)), set(s) - set(values) + ) async def test_srandmember(self, r: redis.Redis): s = [b"1", b"2", b"3"] @@ -1412,11 +1430,13 @@ async def test_sunionstore(self, r: redis.Redis): async def test_zadd(self, r: redis.Redis): mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} await r.zadd("a", mapping) - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, + response, + [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0)], + [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0]], + ) # error cases with pytest.raises(exceptions.DataError): @@ -1433,23 +1453,24 @@ async def test_zadd(self, r: redis.Redis): async def test_zadd_nx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) async def test_zadd_xx(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 - assert await r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + response = await 
r.zrange("a", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a1", 99.0)], [[b"a1", 99.0]]) async def test_zadd_ch(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 assert await r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 - assert await r.zrange("a", 0, -1, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 99.0), - ] + response = await r.zrange("a", 0, -1, withscores=True) + assert_resp_response( + r, response, [(b"a1", 99.0), (b"a2", 2.0)], [[b"a1", 99.0], [b"a2", 2.0]] + ) async def test_zadd_incr(self, r: redis.Redis): assert await r.zadd("a", {"a1": 1}) == 1 @@ -1473,6 +1494,25 @@ async def test_zcount(self, r: redis.Redis): assert await r.zcount("a", 1, "(" + str(2)) == 1 assert await r.zcount("a", 10, 20) == 0 + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiff(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiff(["a", "b"]) == [b"a3"] + response = await r.zdiff(["a", "b"], withscores=True) + assert_resp_response(r, response, [b"a3", b"3"], [[b"a3", 3.0]]) + + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("6.2.0") + async def test_zdiffstore(self, r): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 1, "a2": 2}) + assert await r.zdiffstore("out", ["a", "b"]) + assert await r.zrange("out", 0, -1) == [b"a3"] + response = await r.zrange("out", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0)], [[b"a3", 3.0]]) + async def test_zincrby(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) assert await r.zincrby("a", 1, "a2") == 3.0 @@ -1492,7 +1532,8 @@ async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"]) == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8.0], [b"a1", 9.0]]) @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: redis.Redis): @@ -1500,7 +1541,8 @@ async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]]) @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): @@ -1508,7 +1550,8 @@ async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a1", 1), (b"a3", 3)], [[b"a1", 1], [b"a3", 3]]) @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): @@ -1516,7 +1559,8 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) 
== 2 - assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]]) @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): From 984f7d693b49dbc0487ddccce463825c29d0e7c3 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Wed, 15 Mar 2023 11:25:26 +0200 Subject: [PATCH 11/21] resp3 parser for async cluster --- redis/asyncio/cluster.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 905ece3965..e4603a0986 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -244,6 +244,7 @@ def __init__( ssl_certfile: Optional[str] = None, ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, + protocol: Optional[int] = 2, ) -> None: if db: raise RedisClusterException( @@ -284,6 +285,7 @@ def __init__( "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, "retry": retry, + "protocol": protocol, } if ssl: From eacdf83cf8e67e605939ad30e69d8a3fa679d9ec Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 21 Mar 2023 12:36:08 +0200 Subject: [PATCH 12/21] async commands tests --- redis/client.py | 17 ++- tests/test_asyncio/test_commands.py | 198 ++++++++++++---------------- tests/test_cluster.py | 68 +++++++--- 3 files changed, 142 insertions(+), 141 deletions(-) diff --git a/redis/client.py b/redis/client.py index 55ae33fc64..c5c977966c 100755 --- a/redis/client.py +++ b/redis/client.py @@ -318,7 +318,10 @@ def parse_xautoclaim(response, **options): def parse_xinfo_stream(response, **options): - data = pairs_to_dict(response, decode_keys=True) + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(k): v for k, v in response.items()} if not options.get("full", False): first = data["first-entry"] if first is not None: @@ -584,7 +587,10 @@ def parse_client_kill(response, **options): def parse_acl_getuser(response, **options): if response is None: return None - data = pairs_to_dict(response, decode_keys=True) + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(key): value for key, value in response.items()} # convert everything but user-defined data in 'keys' to native strings data["flags"] = list(map(str_if_bytes, data["flags"])) @@ -876,12 +882,13 @@ class AbstractRedis: "XINFO CONSUMERS": lambda r: [ {str_if_bytes(key): value for key, value in x.items()} for x in r ], - "XINFO STREAM": lambda r, **options: { - str_if_bytes(key): str_if_bytes(value) for key, value in r.items() - }, "MEMORY STATS": lambda r: { str_if_bytes(key): value for key, value in r.items() }, + "XINFO GROUPS": lambda r: [ + {str_if_bytes(key): value for key, value in d.items()} + for d in r + ], } diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 132bbc2f90..51d9c9b44d 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1469,7 +1469,7 @@ async def test_zadd_ch(self, r: redis.Redis): assert await r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 response = await r.zrange("a", 0, -1, withscores=True) assert_resp_response( - r, response, [(b"a1", 99.0), (b"a2", 2.0)], [[b"a1", 99.0], [b"a2", 2.0]] + r, response, [(b"a2", 2.0), (b"a1", 99.0)], [[b"a2", 2.0], [b"a1", 99.0]] ) async def test_zadd_incr(self, r: redis.Redis): @@ -1565,18 
+1565,22 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmax("a") == [(b"a3", 3)] + response = await r.zpopmax("a") + assert_resp_response(r, response, [(b"a3", 3)], [b"a3", 3.0]) # with count - assert await r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + response = await r.zpopmax("a", count=2) + assert_resp_response(r, response, [(b"a2", 2), (b"a1", 1)], [[b"a2", 2], [b"a1", 1]]) @skip_if_server_version_lt("4.9.0") async def test_zpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) - assert await r.zpopmin("a") == [(b"a1", 1)] + response = await r.zpopmin("a") + assert_resp_response(r, response, [(b"a1", 1)], [b"a1", 1.0]) # with count - assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + response = await r.zpopmin("a", count=2) + assert_resp_response(r, response, [(b"a2", 2), (b"a3", 3)], [[b"a2", 2], [b"a3", 3]]) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster @@ -1610,20 +1614,16 @@ async def test_zrange(self, r: redis.Redis): assert await r.zrange("a", 1, 2) == [b"a2", b"a3"] # withscores - assert await r.zrange("a", 0, 1, withscores=True) == [ - (b"a1", 1.0), - (b"a2", 2.0), - ] - assert await r.zrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - ] + response = await r.zrange("a", 0, 1, withscores=True) + assert_resp_response(r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]]) + response = await r.zrange("a", 1, 2, withscores=True) + assert_resp_response(r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]]) # custom score function - assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a1", 1), - (b"a2", 2), - ] + # assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + # (b"a1", 1), + # (b"a2", 2), + # ] @skip_if_server_version_lt("2.8.9") async def test_zrangebylex(self, r: redis.Redis): @@ -1657,16 +1657,18 @@ async def test_zrangebyscore(self, r: redis.Redis): assert await r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] # withscores - assert await r.zrangebyscore("a", 2, 4, withscores=True) == [ - (b"a2", 2.0), - (b"a3", 3.0), - (b"a4", 4.0), - ] + response = await r.zrangebyscore("a", 2, 4, withscores=True) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]] + ) # custom score function - assert await r.zrangebyscore( + response = await r.zrangebyscore( "a", 2, 4, withscores=True, score_cast_func=int - ) == [(b"a2", 2), (b"a3", 3), (b"a4", 4)] + ) + assert_resp_response( + r, response, [(b"a2", 2), (b"a3", 3), (b"a4", 4)], [[b"a2", 2], [b"a3", 3], [b"a4", 4]] + ) async def test_zrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1714,20 +1716,14 @@ async def test_zrevrange(self, r: redis.Redis): assert await r.zrevrange("a", 1, 2) == [b"a2", b"a1"] # withscores - assert await r.zrevrange("a", 0, 1, withscores=True) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] - assert await r.zrevrange("a", 1, 2, withscores=True) == [ - (b"a2", 2.0), - (b"a1", 1.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True) + assert_resp_response(r, response, [(b"a3", 3.0), (b"a2", 2.0)], [[b"a3", 3.0], [b"a2", 2.0]]) + response = await r.zrevrange("a", 1, 2, withscores=True) + assert_resp_response(r, response, [(b"a2", 2.0), 
(b"a1", 1.0)], [[b"a2", 2.0], [b"a1", 1.0]]) # custom score function - assert await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) + assert_resp_response(r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]]) async def test_zrevrangebyscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1737,16 +1733,12 @@ async def test_zrevrangebyscore(self, r: redis.Redis): assert await r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] # withscores - assert await r.zrevrangebyscore("a", 4, 2, withscores=True) == [ - (b"a4", 4.0), - (b"a3", 3.0), - (b"a2", 2.0), - ] + response = await r.zrevrangebyscore("a", 4, 2, withscores=True) + assert_resp_response(r, response, [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]]) # custom score function - assert await r.zrevrangebyscore( - "a", 4, 2, withscores=True, score_cast_func=int - ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + response = await r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) + assert_resp_response(r, response, [(b"a4", 4), (b"a3", 3), (b"a2", 2)], [[b"a4", 4], [b"a3", 3], [b"a2", 2]]) async def test_zrevrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1766,12 +1758,8 @@ async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"]) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 3), - (b"a4", 4), - (b"a3", 8), - (b"a1", 9), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a2", 3.0), (b"a4", 4.0), (b"a3", 8.0), (b"a1", 9.0)], [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]]) @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): @@ -1779,12 +1767,8 @@ async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 2), - (b"a4", 4), - (b"a3", 5), - (b"a1", 6), - ] + respponse = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, respponse, [(b"a2", 2.0), (b"a4", 4.0), (b"a3", 5.0), (b"a1", 6.0)], [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]]) @pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): @@ -1792,12 +1776,8 @@ async def test_zunionstore_min(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a1", 1), - (b"a2", 2), - (b"a3", 3), - (b"a4", 4), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]]) @pytest.mark.onlynoncluster async def test_zunionstore_with_weight(self, r: redis.Redis): @@ -1805,12 +1785,8 @@ async def test_zunionstore_with_weight(self, r: redis.Redis): await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await 
r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 - assert await r.zrange("d", 0, -1, withscores=True) == [ - (b"a2", 5), - (b"a4", 12), - (b"a3", 20), - (b"a1", 23), - ] + response = await r.zrange("d", 0, -1, withscores=True) + assert_resp_response(r, response, [(b"a2", 5.0), (b"a4", 12.0), (b"a3", 20.0), (b"a1", 23.0)], [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]]) # HYPERLOGLOG TESTS @skip_if_server_version_lt("2.8.9") @@ -2805,28 +2781,24 @@ async def test_xread(self, r: redis.Redis): m1 = await r.xadd(stream, {"foo": "bar"}) m2 = await r.xadd(stream, {"bing": "baz"}) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + strem_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] # xread starting at 0 returns both messages - assert await r.xread(streams={stream: 0}) == expected + res = await r.xread(streams={stream: 0}) + assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] # xread starting at 0 and count=1 returns only the first message - assert await r.xread(streams={stream: 0}, count=1) == expected + res = await r.xread(streams={stream: 0}, count=1) + assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}) - expected = [[stream.encode(), [await get_stream_message(r, stream, m2)]]] + expected_entries = [await get_stream_message(r, stream, m2)] # xread starting at m1 returns only the second message - assert await r.xread(streams={stream: m1}) == expected - - # xread starting at the last message returns an empty list - assert await r.xread(streams={stream: m2}) == [] + res = await r.xread(streams={stream: m1}) + assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}) @skip_if_server_version_lt("5.0.0") async def test_xreadgroup(self, r: redis.Redis): @@ -2837,27 +2809,24 @@ async def test_xreadgroup(self, r: redis.Redis): m2 = await r.xadd(stream, {"bing": "baz"}) await r.xgroup_create(stream, group, 0) - expected = [ - [ - stream.encode(), - [ - await get_stream_message(r, stream, m1), - await get_stream_message(r, stream, m2), - ], - ] + strem_name = stream.encode() + expected_entries = [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), ] + # xread starting at 0 returns both messages - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}) await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, 0) - expected = [[stream.encode(), [await get_stream_message(r, stream, m1)]]] + expected_entries = [await get_stream_message(r, stream, m1)] + # xread with count=1 returns only the first message - assert ( - await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) - == expected - ) + res = await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}) await r.xgroup_destroy(stream, group) @@ -2865,35 +2834,32 @@ async def test_xreadgroup(self, r: redis.Redis): # will only find messages added after this 
await r.xgroup_create(stream, group, "$") - expected = [] # xread starting after the last message returns an empty message list - assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: ">"}) + assert_resp_response(r, res, [], {}) # xreadgroup with noack does not have any items in the PEL await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") - assert ( - len( - ( - await r.xreadgroup( - group, consumer, streams={stream: ">"}, noack=True - ) - )[0][1] - ) - == 2 - ) - # now there should be nothing pending - assert ( - len((await r.xreadgroup(group, consumer, streams={stream: "0"}))[0][1]) == 0 - ) + # res = r.xreadgroup(group, consumer, streams={stream: ">"}, noack=True) + # empty_res = r.xreadgroup(group, consumer, streams={stream: "0"}) + # if is_resp2_connection(r): + # assert len(res[0][1]) == 2 + # # now there should be nothing pending + # assert len(empty_res[0][1]) == 0 + # else: + # assert len(res[strem_name][0]) == 2 + # # now there should be nothing pending + # assert len(empty_res[strem_name][0]) == 0 await r.xgroup_destroy(stream, group) await r.xgroup_create(stream, group, "0") # delete all the messages in the stream - expected = [[stream.encode(), [(m1, {}), (m2, {})]]] + expected_entries = [(m1, {}), (m2, {})] await r.xreadgroup(group, consumer, streams={stream: ">"}) await r.xtrim(stream, 0) - assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + res = await r.xreadgroup(group, consumer, streams={stream: "0"}) + assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}) @skip_if_server_version_lt("5.0.0") async def test_xrevrange(self, r: redis.Redis): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 3dc23bb852..ca1e33a793 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -42,6 +42,7 @@ skip_if_server_version_lt, skip_unless_arch_bits, wait_for_command, + is_resp2_connection, ) default_host = "127.0.0.1" @@ -1723,7 +1724,10 @@ def test_cluster_zdiff(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2, "a3": 3}) r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiff(["{foo}a", "{foo}b"]) == [b"a3"] - assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + if is_resp2_connection(r): + assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [b"a3", b"3"] + else: + assert r.zdiff(["{foo}a", "{foo}b"], withscores=True) == [[b"a3", 3.0]] @skip_if_server_version_lt("6.2.0") def test_cluster_zdiffstore(self, r): @@ -1731,7 +1735,10 @@ def test_cluster_zdiffstore(self, r): r.zadd("{foo}b", {"a1": 1, "a2": 2}) assert r.zdiffstore("{foo}out", ["{foo}a", "{foo}b"]) assert r.zrange("{foo}out", 0, -1) == [b"a3"] - assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + if is_resp2_connection(r): + assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b"a3", 3.0)] + else: + assert r.zrange("{foo}out", 0, -1, withscores=True) == [[b"a3", 3.0]] @skip_if_server_version_lt("6.2.0") def test_cluster_zinter(self, r): @@ -1742,24 +1749,45 @@ def test_cluster_zinter(self, r): # invalid aggregation with pytest.raises(DataError): r.zinter(["{foo}a", "{foo}b", "{foo}c"], aggregate="foo", withscores=True) - # aggregate with SUM - assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ - (b"a3", 8), - (b"a1", 9), - ] - # aggregate with MAX - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True - ) == [(b"a3", 5), (b"a1", 6)] - # 
aggregate with MIN - assert r.zinter( - ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True - ) == [(b"a1", 1), (b"a3", 1)] - # with weights - assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ - (b"a3", 20), - (b"a1", 23), - ] + if is_resp2_connection(r): + # aggregate with SUM + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + (b"a3", 8), + (b"a1", 9), + ] + # aggregate with MAX + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [(b"a3", 5), (b"a1", 6)] + # aggregate with MIN + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [(b"a1", 1), (b"a3", 1)] + # with weights + assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ + (b"a3", 20), + (b"a1", 23), + ] + else: + # aggregate with SUM + assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [ + [b"a3", 8] + [b"a1", 9] + ] + # aggregate with MAX + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True + ) == [[b"a3", 5] [b"a1", 6]] + # aggregate with MIN + assert r.zinter( + ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True + ) == [[b"a1", 1] [b"a3", 1]] + # with weights + assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [ + [b"a3", 2], + [b"a1", 2], + ] + def test_cluster_zinterstore_sum(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1}) From 1591480259d15b2b065a22c4c1bc1602fb12d3c9 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Tue, 21 Mar 2023 16:38:03 +0200 Subject: [PATCH 13/21] linters --- redis/asyncio/connection.py | 5 +- redis/client.py | 3 +- redis/connection.py | 7 +- redis/parsers/base.py | 4 +- redis/parsers/hiredis.py | 3 +- redis/parsers/resp3.py | 6 +- tests/test_asyncio/conftest.py | 4 +- tests/test_asyncio/test_commands.py | 125 ++++++++++++++++++++++------ tests/test_cluster.py | 24 +++--- tests/test_connection.py | 6 +- 10 files changed, 127 insertions(+), 60 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 7120636c1a..d9c95834d5 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -365,8 +365,9 @@ async def on_connect(self) -> None: self._parser.on_connect(self) await self.send_command("HELLO", self.protocol) response = await self.read_response() - if (response.get(b"proto") != int(self.protocol) and - response.get("proto") != int(self.protocol)): + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): raise ConnectionError("Invalid RESP version") # if a client_name is given, set it diff --git a/redis/client.py b/redis/client.py index c5c977966c..15dddc9bd7 100755 --- a/redis/client.py +++ b/redis/client.py @@ -886,8 +886,7 @@ class AbstractRedis: str_if_bytes(key): value for key, value in r.items() }, "XINFO GROUPS": lambda r: [ - {str_if_bytes(key): value for key, value in d.items()} - for d in r + {str_if_bytes(key): value for key, value in d.items()} for d in r ], } diff --git a/redis/connection.py b/redis/connection.py index 2c86a077bd..85509f7ef7 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2,8 +2,6 @@ import os import socket import ssl -import threading -import weakref import sys import threading import weakref @@ -310,8 +308,9 @@ def on_connect(self): self._parser.on_connect(self) self.send_command("HELLO", self.protocol) response = self.read_response() - if (response.get(b"proto") != int(self.protocol) and - 
response.get("proto") != int(self.protocol)): + if response.get(b"proto") != int(self.protocol) and response.get( + "proto" + ) != int(self.protocol): raise ConnectionError("Invalid RESP version") # if a client_name is given, set it diff --git a/redis/parsers/base.py b/redis/parsers/base.py index 53bbae54b3..b98a44ef2f 100644 --- a/redis/parsers/base.py +++ b/redis/parsers/base.py @@ -1,7 +1,7 @@ +import sys from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError from typing import List, Optional, Union -import sys if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout @@ -166,7 +166,7 @@ def _clear(self): self._buffer = b"" self._chunks.clear() - def on_connect(self, connection: "Connection"): + def on_connect(self, connection): """Called when the stream connects""" self._stream = connection._reader if self._stream is None: diff --git a/redis/parsers/hiredis.py b/redis/parsers/hiredis.py index 0963b002a5..b3247b71ec 100644 --- a/redis/parsers/hiredis.py +++ b/redis/parsers/hiredis.py @@ -1,12 +1,13 @@ import asyncio import socket -from typing import Callable, List, Optional, Union import sys +from typing import Callable, List, Optional, Union if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: from async_timeout import timeout as async_timeout + from redis.compat import TypedDict from ..exceptions import ConnectionError, InvalidResponse, RedisError diff --git a/redis/parsers/resp3.py b/redis/parsers/resp3.py index d79a69e687..2753d39f1a 100644 --- a/redis/parsers/resp3.py +++ b/redis/parsers/resp3.py @@ -161,9 +161,9 @@ async def _read_response( # map response elif byte == b"%": response = { - (await self._read_response( - disable_decoding=disable_decoding - )): (await self._read_response(disable_decoding=disable_decoding)) + (await self._read_response(disable_decoding=disable_decoding)): ( + await self._read_response(disable_decoding=disable_decoding) + ) for _ in range(int(response)) } else: diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 1fb5cf0651..e51836294b 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -235,12 +235,14 @@ async def wait_for_command( if key in monitor_response["command"]: return None + def get_protocol_version(r): if isinstance(r, redis.Redis): return r.connection_pool.connection_kwargs.get("protocol") elif isinstance(r, redis.RedisCluster): return r.nodes_manager.connection_kwargs.get("protocol") - + + def assert_resp_response(r, response, resp2_expected, resp3_expected): protocol = get_protocol_version(r) if protocol == "2" or protocol is None: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 51d9c9b44d..866929b2e4 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -19,6 +19,7 @@ ) from .conftest import assert_resp_response, assert_resp_response_in + REDIS_6_VERSION = "5.9.0" @@ -1533,7 +1534,9 @@ async def test_zinterstore_sum(self, r: redis.Redis): await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"]) == 2 response = await r.zrange("d", 0, -1, withscores=True) - assert_resp_response(r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8.0], [b"a1", 9.0]]) + assert_resp_response( + r, response, [(b"a3", 8), (b"a1", 9)], [[b"a3", 8.0], [b"a1", 9.0]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_max(self, r: 
redis.Redis): @@ -1542,7 +1545,9 @@ async def test_zinterstore_max(self, r: redis.Redis): await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 response = await r.zrange("d", 0, -1, withscores=True) - assert_resp_response(r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]]) + assert_resp_response( + r, response, [(b"a3", 5), (b"a1", 6)], [[b"a3", 5], [b"a1", 6]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_min(self, r: redis.Redis): @@ -1551,7 +1556,9 @@ async def test_zinterstore_min(self, r: redis.Redis): await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 response = await r.zrange("d", 0, -1, withscores=True) - assert_resp_response(r, response, [(b"a1", 1), (b"a3", 3)], [[b"a1", 1], [b"a3", 3]]) + assert_resp_response( + r, response, [(b"a1", 1), (b"a3", 3)], [[b"a1", 1], [b"a3", 3]] + ) @pytest.mark.onlynoncluster async def test_zinterstore_with_weight(self, r: redis.Redis): @@ -1560,7 +1567,9 @@ async def test_zinterstore_with_weight(self, r: redis.Redis): await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 response = await r.zrange("d", 0, -1, withscores=True) - assert_resp_response(r, response, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]]) + assert_resp_response( + r, response, [(b"a3", 20), (b"a1", 23)], [[b"a3", 20], [b"a1", 23]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmax(self, r: redis.Redis): @@ -1570,7 +1579,9 @@ async def test_zpopmax(self, r: redis.Redis): # with count response = await r.zpopmax("a", count=2) - assert_resp_response(r, response, [(b"a2", 2), (b"a1", 1)], [[b"a2", 2], [b"a1", 1]]) + assert_resp_response( + r, response, [(b"a2", 2), (b"a1", 1)], [[b"a2", 2], [b"a1", 1]] + ) @skip_if_server_version_lt("4.9.0") async def test_zpopmin(self, r: redis.Redis): @@ -1580,7 +1591,9 @@ async def test_zpopmin(self, r: redis.Redis): # with count response = await r.zpopmin("a", count=2) - assert_resp_response(r, response, [(b"a2", 2), (b"a3", 3)], [[b"a2", 2], [b"a3", 3]]) + assert_resp_response( + r, response, [(b"a2", 2), (b"a3", 3)], [[b"a2", 2], [b"a3", 3]] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster @@ -1615,9 +1628,13 @@ async def test_zrange(self, r: redis.Redis): # withscores response = await r.zrange("a", 0, 1, withscores=True) - assert_resp_response(r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]]) + assert_resp_response( + r, response, [(b"a1", 1.0), (b"a2", 2.0)], [[b"a1", 1.0], [b"a2", 2.0]] + ) response = await r.zrange("a", 1, 2, withscores=True) - assert_resp_response(r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]]) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a3", 3.0)], [[b"a2", 2.0], [b"a3", 3.0]] + ) # custom score function # assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ @@ -1659,7 +1676,10 @@ async def test_zrangebyscore(self, r: redis.Redis): # withscores response = await r.zrangebyscore("a", 2, 4, withscores=True) assert_resp_response( - r, response, [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]] + r, + response, + [(b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], + [[b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]], ) # custom score function @@ -1667,7 +1687,10 @@ async def test_zrangebyscore(self, r: redis.Redis): "a", 2, 4, withscores=True, score_cast_func=int ) 
assert_resp_response( - r, response, [(b"a2", 2), (b"a3", 3), (b"a4", 4)], [[b"a2", 2], [b"a3", 3], [b"a4", 4]] + r, + response, + [(b"a2", 2), (b"a3", 3), (b"a4", 4)], + [[b"a2", 2], [b"a3", 3], [b"a4", 4]], ) async def test_zrank(self, r: redis.Redis): @@ -1717,13 +1740,19 @@ async def test_zrevrange(self, r: redis.Redis): # withscores response = await r.zrevrange("a", 0, 1, withscores=True) - assert_resp_response(r, response, [(b"a3", 3.0), (b"a2", 2.0)], [[b"a3", 3.0], [b"a2", 2.0]]) + assert_resp_response( + r, response, [(b"a3", 3.0), (b"a2", 2.0)], [[b"a3", 3.0], [b"a2", 2.0]] + ) response = await r.zrevrange("a", 1, 2, withscores=True) - assert_resp_response(r, response, [(b"a2", 2.0), (b"a1", 1.0)], [[b"a2", 2.0], [b"a1", 1.0]]) + assert_resp_response( + r, response, [(b"a2", 2.0), (b"a1", 1.0)], [[b"a2", 2.0], [b"a1", 1.0]] + ) # custom score function response = await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) - assert_resp_response(r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]]) + assert_resp_response( + r, response, [(b"a3", 3), (b"a2", 2)], [[b"a3", 3], [b"a2", 2]] + ) async def test_zrevrangebyscore(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1734,11 +1763,23 @@ async def test_zrevrangebyscore(self, r: redis.Redis): # withscores response = await r.zrevrangebyscore("a", 4, 2, withscores=True) - assert_resp_response(r, response, [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]]) + assert_resp_response( + r, + response, + [(b"a4", 4.0), (b"a3", 3.0), (b"a2", 2.0)], + [[b"a4", 4.0], [b"a3", 3.0], [b"a2", 2.0]], + ) # custom score function - response = await r.zrevrangebyscore("a", 4, 2, withscores=True, score_cast_func=int) - assert_resp_response(r, response, [(b"a4", 4), (b"a3", 3), (b"a2", 2)], [[b"a4", 4], [b"a3", 3], [b"a2", 2]]) + response = await r.zrevrangebyscore( + "a", 4, 2, withscores=True, score_cast_func=int + ) + assert_resp_response( + r, + response, + [(b"a4", 4), (b"a3", 3), (b"a2", 2)], + [[b"a4", 4], [b"a3", 3], [b"a2", 2]], + ) async def test_zrevrank(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) @@ -1759,7 +1800,12 @@ async def test_zunionstore_sum(self, r: redis.Redis): await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"]) == 4 response = await r.zrange("d", 0, -1, withscores=True) - assert_resp_response(r, response, [(b"a2", 3.0), (b"a4", 4.0), (b"a3", 8.0), (b"a1", 9.0)], [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]]) + assert_resp_response( + r, + response, + [(b"a2", 3.0), (b"a4", 4.0), (b"a3", 8.0), (b"a1", 9.0)], + [[b"a2", 3.0], [b"a4", 4.0], [b"a3", 8.0], [b"a1", 9.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_max(self, r: redis.Redis): @@ -1768,7 +1814,12 @@ async def test_zunionstore_max(self, r: redis.Redis): await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 respponse = await r.zrange("d", 0, -1, withscores=True) - assert_resp_response(r, respponse, [(b"a2", 2.0), (b"a4", 4.0), (b"a3", 5.0), (b"a1", 6.0)], [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]]) + assert_resp_response( + r, + respponse, + [(b"a2", 2.0), (b"a4", 4.0), (b"a3", 5.0), (b"a1", 6.0)], + [[b"a2", 2.0], [b"a4", 4.0], [b"a3", 5.0], [b"a1", 6.0]], + ) @pytest.mark.onlynoncluster async def test_zunionstore_min(self, r: redis.Redis): @@ -1777,7 +1828,12 @@ async def 
test_zunionstore_min(self, r: redis.Redis):
         await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4})
         assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4
         response = await r.zrange("d", 0, -1, withscores=True)
-        assert_resp_response(r, response, [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)], [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]])
+        assert_resp_response(
+            r,
+            response,
+            [(b"a1", 1.0), (b"a2", 2.0), (b"a3", 3.0), (b"a4", 4.0)],
+            [[b"a1", 1.0], [b"a2", 2.0], [b"a3", 3.0], [b"a4", 4.0]],
+        )
 
     @pytest.mark.onlynoncluster
     async def test_zunionstore_with_weight(self, r: redis.Redis):
@@ -1786,7 +1842,12 @@ async def test_zunionstore_with_weight(self, r: redis.Redis):
         await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4})
         assert await r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4
         response = await r.zrange("d", 0, -1, withscores=True)
-        assert_resp_response(r, response, [(b"a2", 5.0), (b"a4", 12.0), (b"a3", 20.0), (b"a1", 23.0)], [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]])
+        assert_resp_response(
+            r,
+            response,
+            [(b"a2", 5.0), (b"a4", 12.0), (b"a3", 20.0), (b"a1", 23.0)],
+            [[b"a2", 5.0], [b"a4", 12.0], [b"a3", 20.0], [b"a1", 23.0]],
+        )
 
     # HYPERLOGLOG TESTS
     @skip_if_server_version_lt("2.8.9")
@@ -2788,17 +2849,23 @@ async def test_xread(self, r: redis.Redis):
         ]
         # xread starting at 0 returns both messages
         res = await r.xread(streams={stream: 0})
-        assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]})
+        assert_resp_response(
+            r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}
+        )
 
         expected_entries = [await get_stream_message(r, stream, m1)]
         # xread starting at 0 and count=1 returns only the first message
         res = await r.xread(streams={stream: 0}, count=1)
-        assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]})
+        assert_resp_response(
+            r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}
+        )
 
         expected_entries = [await get_stream_message(r, stream, m2)]
         # xread starting at m1 returns only the second message
         res = await r.xread(streams={stream: m1})
-        assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]})
+        assert_resp_response(
+            r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}
+        )
 
     @skip_if_server_version_lt("5.0.0")
     async def test_xreadgroup(self, r: redis.Redis):
@@ -2817,7 +2884,9 @@ async def test_xreadgroup(self, r: redis.Redis):
 
         # xread starting at 0 returns both messages
         res = await r.xreadgroup(group, consumer, streams={stream: ">"})
-        assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]})
+        assert_resp_response(
+            r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}
+        )
 
         await r.xgroup_destroy(stream, group)
         await r.xgroup_create(stream, group, 0)
@@ -2826,7 +2895,9 @@ async def test_xreadgroup(self, r: redis.Redis):
 
         # xread with count=1 returns only the first message
         res = await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1)
-        assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]})
+        assert_resp_response(
+            r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}
+        )
 
         await r.xgroup_destroy(stream, group)
 
@@ -2859,7 +2930,9 @@ async def test_xreadgroup(self, r: redis.Redis):
         await r.xreadgroup(group, consumer, streams={stream: ">"})
         await r.xtrim(stream, 0)
         res = await r.xreadgroup(group, consumer, streams={stream: "0"})
-        assert_resp_response(r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]})
+        assert_resp_response(
+            r, res, [[strem_name, expected_entries]], {strem_name: [expected_entries]}
+        )
 
     @skip_if_server_version_lt("5.0.0")
     async def test_xrevrange(self, r: redis.Redis):
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index 8f16cd3fd7..993bfab9cd 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -39,11 +39,11 @@
 
 from .conftest import (
     _get_client,
+    is_resp2_connection,
     skip_if_redis_enterprise,
     skip_if_server_version_lt,
     skip_unless_arch_bits,
     wait_for_command,
-    is_resp2_connection,
 )
 
 default_host = "127.0.0.1"
@@ -1765,30 +1765,26 @@ def test_cluster_zinter(self, r):
                 ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True
             ) == [(b"a1", 1), (b"a3", 1)]
             # with weights
-            assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [
-                (b"a3", 20),
-                (b"a1", 23),
-            ]
+            assert r.zinter(
+                {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True
+            ) == [(b"a3", 20), (b"a1", 23)]
         else:
             # aggregate with SUM
             assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [
-                [b"a3", 8]
-                [b"a1", 9]
+                [b"a3", 8][b"a1", 9]
             ]
             # aggregate with MAX
             assert r.zinter(
                 ["{foo}a", "{foo}b", "{foo}c"], aggregate="MAX", withscores=True
-            ) == [[b"a3", 5] [b"a1", 6]]
+            ) == [[b"a3", 5], [b"a1", 6]]
             # aggregate with MIN
             assert r.zinter(
                 ["{foo}a", "{foo}b", "{foo}c"], aggregate="MIN", withscores=True
-            ) == [[b"a1", 1] [b"a3", 1]]
+            ) == [[b"a1", 1], [b"a3", 1]]
             # with weights
-            assert r.zinter({"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True) == [
-                [b"a3", 2],
-                [b"a1", 2],
-            ]
-
+            assert r.zinter(
+                {"{foo}a": 1, "{foo}b": 2, "{foo}c": 3}, withscores=True
+            ) == [[b"a3", 2], [b"a1", 2]]
 
     def test_cluster_zinterstore_sum(self, r):
         r.zadd("{foo}a", {"a1": 1, "a2": 1, "a3": 1})
diff --git a/tests/test_connection.py b/tests/test_connection.py
index c284499c62..facd425061 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -7,11 +7,7 @@
 
 import redis
 from redis.backoff import NoBackoff
-from redis.connection import (
-    Connection,
-    SSLConnection,
-    UnixDomainSocketConnection,
-)
+from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection
 from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
 from redis.parsers import _HiredisParser, _RESP2Parser, _RESP3Parser
 from redis.retry import Retry

From 88074b7be1abc9af0ee215ea2d9d845cb48c9ab1 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Tue, 21 Mar 2023 16:50:46 +0200
Subject: [PATCH 14/21] linters

---
 tests/test_asyncio/test_connection.py | 10 +++++-----
 tests/test_cluster.py                 |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index 4199a167d8..6f3e9104af 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -9,14 +9,12 @@
 from redis.asyncio import Redis
 from redis.asyncio.connection import (
     Connection,
-    HiredisParser,
-    PythonParser,
     UnixDomainSocketConnection,
 )
 from redis.asyncio.retry import Retry
 from redis.backoff import NoBackoff
 from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
-from redis.parsers import _AsyncRESP2Parser
+from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser, _AsyncRESP3Parser
 from redis.utils import HIREDIS_AVAILABLE
 from tests.conftest import skip_if_server_version_lt
 
@@ -197,7 +195,9 @@ async def test_connection_parse_response_resume(r: redis.Redis):
 
 @pytest.mark.onlynoncluster
 @pytest.mark.parametrize(
-    "parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"]
+    "parser_class",
+    [_AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncHiredisParser],
+    ids=["AsyncRESP2Parser", "AsyncRESP3Parser", "AsyncHiredisParser"],
 )
 async def test_connection_disconect_race(parser_class):
     """
@@ -211,7 +211,7 @@ async def test_connection_disconect_race(parser_class):
     This test verifies that a read in progress can finish even
     if the `disconnect()` method is called.
     """
-    if parser_class == HiredisParser and not HIREDIS_AVAILABLE:
+    if parser_class == _AsyncHiredisParser and not HIREDIS_AVAILABLE:
         pytest.skip("Hiredis not available")
 
     args = {}
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index 993bfab9cd..ccc7040559 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -1771,7 +1771,7 @@ def test_cluster_zinter(self, r):
         else:
             # aggregate with SUM
             assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [
-                [b"a3", 8][b"a1", 9]
+                [b"a3", 8], [b"a1", 9]
             ]
             # aggregate with MAX
             assert r.zinter(

From 15fbb676b32d37b930ddab4454a8644adbea69a6 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Tue, 21 Mar 2023 17:02:48 +0200
Subject: [PATCH 15/21] linters

---
 tests/test_asyncio/test_connection.py | 5 +----
 tests/test_cluster.py                 | 3 ++-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index 6f3e9104af..5e47eedabc 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -7,10 +7,7 @@
 
 import redis
 from redis.asyncio import Redis
-from redis.asyncio.connection import (
-    Connection,
-    UnixDomainSocketConnection,
-)
+from redis.asyncio.connection import Connection, UnixDomainSocketConnection
 from redis.asyncio.retry import Retry
 from redis.backoff import NoBackoff
 from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index ccc7040559..4a43eaea21 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -1771,7 +1771,8 @@ def test_cluster_zinter(self, r):
         else:
             # aggregate with SUM
             assert r.zinter(["{foo}a", "{foo}b", "{foo}c"], withscores=True) == [
-                [b"a3", 8], [b"a1", 9]
+                [b"a3", 8],
+                [b"a1", 9],
             ]
             # aggregate with MAX
             assert r.zinter(

From c2656fc8208c9ed18a53108ff9c131a486a84308 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Tue, 21 Mar 2023 17:49:02 +0200
Subject: [PATCH 16/21] fix ModuleNotFoundError

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index a0710d3b73..52e0a8a901 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,7 @@
             "redis.commands.search",
             "redis.commands.timeseries",
             "redis.commands.graph",
+            "redis.parsers",
         ]
     ),
     url="https://github.com/redis/redis-py",

From c3fcbeab41a947c41fef0f782f54278012d27904 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Wed, 22 Mar 2023 02:55:08 +0200
Subject: [PATCH 17/21] fix tests

---
 tests/test_asyncio/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py
index e51836294b..b49b8107dd 100644
--- a/tests/test_asyncio/conftest.py
+++ b/tests/test_asyncio/conftest.py
@@ -245,7 +245,7 @@ def get_protocol_version(r):
 
 def assert_resp_response(r, response, resp2_expected, resp3_expected):
     protocol = get_protocol_version(r)
-    if protocol == "2" or protocol is None:
+    if protocol in [2, "2", None]:
         assert response == resp2_expected
     else:
         assert response == resp3_expected

From 521736fa850dbce2fd1fa5c52f314754b3dd0a0b Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Wed, 22 Mar 2023 02:57:39 +0200
Subject: [PATCH 18/21] fix assert_resp_response_in

---
 tests/test_asyncio/conftest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py
index b49b8107dd..e8ab6b297f 100644
--- a/tests/test_asyncio/conftest.py
+++ b/tests/test_asyncio/conftest.py
@@ -253,7 +253,7 @@ def assert_resp_response(r, response, resp2_expected, resp3_expected):
 
 def assert_resp_response_in(r, response, resp2_expected, resp3_expected):
     protocol = get_protocol_version(r)
-    if protocol == "2" or protocol is None:
+    if protocol in [2, "2", None]:
         assert response in resp2_expected
     else:
         assert response in resp3_expected

From b81056098ecec565d94910657729eeca879afee2 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Wed, 22 Mar 2023 13:45:39 +0200
Subject: [PATCH 19/21] fix command_getkeys in cluster

---
 redis/parsers/commands.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/redis/parsers/commands.py b/redis/parsers/commands.py
index 496336acaa..2ea29a75ae 100644
--- a/redis/parsers/commands.py
+++ b/redis/parsers/commands.py
@@ -126,7 +126,7 @@ def _get_moveable_keys(self, redis_conn, *args):
         # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE']
         pieces = args[0].split() + list(args[1:])
         try:
-            keys = redis_conn.command_getkeys(*pieces)
+            keys = redis_conn.execute_command("COMMAND GETKEYS", *pieces)
         except ResponseError as e:
             message = e.__str__()
             if (

From 353f218fbed59bf234b747536acedfd80d397a0b Mon Sep 17 00:00:00 2001
From: "Chayim I. Kirshen"
Date: Wed, 22 Mar 2023 14:11:08 +0200
Subject: [PATCH 20/21] fail-fast false

---
 .github/workflows/integration.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index 0f9db8fb1a..f49a4fcd46 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -51,6 +51,7 @@ jobs:
     timeout-minutes: 30
     strategy:
       max-parallel: 15
+      fail-fast: false
       matrix:
         python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
         test-type: ['standalone', 'cluster']
@@ -108,6 +109,7 @@ jobs:
     name: Install package from commit hash
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
      matrix:
        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', 'pypy-3.7', 'pypy-3.8', 'pypy-3.9']
     steps:

From 2842f9d0a070f4627f6ff156079304d4bdbadad3 Mon Sep 17 00:00:00 2001
From: dvora-h
Date: Thu, 23 Mar 2023 12:12:36 +0200
Subject: [PATCH 21/21] version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index fc63ae28c4..f37e77df67 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
     long_description_content_type="text/markdown",
     keywords=["Redis", "key-value store", "database"],
     license="MIT",
-    version="4.5.3",
+    version="5.0.0b1",
     packages=find_packages(
         include=[
             "redis",