Skip to content

Commit 2e28576

Browse files
committed
Fix: vector type fields should not be encoded as strings (redis#2772)
1 parent 0d47d65 commit 2e28576

File tree

6 files changed

+62
-15
lines changed

6 files changed

+62
-15
lines changed

CHANGES

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Fix #2772: vector type fields should not be encoded as strings
12
* Update `ResponseT` type hint
23
* Allow to control the minimum SSL version
34
* Add an optional lock_name attribute to LockError.
@@ -59,7 +60,7 @@
5960
* Fix bug where Sentinel.execute_command doesn't execute across the entire sentinel cluster (#2458)
6061
* Added a replacement for the default cluster node in the event of failure (#2463)
6162
* Fix for Unhandled exception related to self.host with unix socket (#2496)
62-
* Improve error output for master discovery
63+
* Improve error output for master discovery
6364
* Make `ClusterCommandsProtocol` an actual Protocol
6465
* Add `sum` to DUPLICATE_POLICY documentation of `TS.CREATE`, `TS.ADD` and `TS.ALTER`
6566
* Prevent async ClusterPipeline instances from becoming "false-y" in case of empty command stack (#3061)

dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ ujson>=4.2.0
1616
wheel>=0.30.0
1717
urllib3<2
1818
uvloop
19+
numpy>=1.24.4

redis/_parsers/encoders.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,9 @@ def decode(self, value, force=False):
4040
if isinstance(value, memoryview):
4141
value = value.tobytes()
4242
if isinstance(value, bytes):
43-
value = value.decode(self.encoding, self.encoding_errors)
43+
try:
44+
value = value.decode(self.encoding, self.encoding_errors)
45+
except UnicodeDecodeError:
46+
# Return the bytes unmodified
47+
return value
4448
return value

redis/_parsers/hiredis.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def read_response(self, disable_decoding=False):
123123
if disable_decoding:
124124
response = self._reader.gets(False)
125125
else:
126-
response = self._reader.gets()
126+
try:
127+
response = self._reader.gets()
128+
except UnicodeDecodeError:
129+
response = self._reader.gets(False)
127130
# if the response is a ConnectionError or the response is a list and
128131
# the first item is a ConnectionError, raise it as something bad
129132
# happened

redis/commands/search/result.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,11 @@ def __init__(
3939

4040
fields = {}
4141
if hascontent and res[i + fields_offset] is not None:
42-
fields = (
43-
dict(
44-
dict(
45-
zip(
46-
map(to_string, res[i + fields_offset][::2]),
47-
map(to_string, res[i + fields_offset][1::2]),
48-
)
49-
)
50-
)
51-
if hascontent
52-
else {}
53-
)
42+
for j in range(0, len(res[i + fields_offset]), 2):
43+
key = to_string(res[i + fields_offset][j])
44+
value = res[i + fields_offset][j + 1]
45+
fields[key] = value
46+
5447
try:
5548
del fields["id"]
5649
except KeyError:

tests/test_search.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from io import TextIOWrapper
66

7+
import numpy as np
78
import pytest
89
import redis
910
import redis.commands.search
@@ -2282,3 +2283,47 @@ def test_geoshape(client: redis.Redis):
22822283
assert result.docs[0]["id"] == "small"
22832284
result = client.ft().search(q2, query_params=qp2)
22842285
assert len(result.docs) == 2
2286+
2287+
2288+
@pytest.mark.redismod
2289+
def test_vector_storage_and_retrieval(client):
2290+
# Constants
2291+
INDEX_NAME = "vector_index"
2292+
DOC_PREFIX = "doc:"
2293+
VECTOR_DIMENSIONS = 4
2294+
VECTOR_FIELD_NAME = "my_vector"
2295+
2296+
# Create index
2297+
client.ft(INDEX_NAME).create_index(
2298+
(
2299+
VectorField(
2300+
VECTOR_FIELD_NAME,
2301+
"FLAT",
2302+
{
2303+
"TYPE": "FLOAT32",
2304+
"DIM": VECTOR_DIMENSIONS,
2305+
"DISTANCE_METRIC": "COSINE",
2306+
},
2307+
),
2308+
),
2309+
definition=IndexDefinition(prefix=[DOC_PREFIX], index_type=IndexType.HASH),
2310+
)
2311+
2312+
# Add a document with a vector value
2313+
vector_data = [0.1, 0.2, 0.3, 0.4]
2314+
client.hset(
2315+
f"{DOC_PREFIX}1",
2316+
mapping={VECTOR_FIELD_NAME: np.array(vector_data, dtype=np.float32).tobytes()},
2317+
)
2318+
2319+
# Perform a search to retrieve the document
2320+
query = Query("*").return_fields(VECTOR_FIELD_NAME).dialect(2)
2321+
res = client.ft(INDEX_NAME).search(query)
2322+
2323+
# Assert that the document is retrieved and the vector matches the original data
2324+
assert res.total == 1
2325+
assert res.docs[0].id == f"{DOC_PREFIX}1"
2326+
retrieved_vector_data = np.frombuffer(
2327+
res.docs[0].__dict__[VECTOR_FIELD_NAME], dtype=np.float32
2328+
)
2329+
assert np.allclose(retrieved_vector_data, vector_data)

0 commit comments

Comments
 (0)