Skip to content

Commit 1b0d29b

Browse files
committed
Fix: Issue with search dialect 3 and JSON (resolves #140)
1 parent ce0f710 commit 1b0d29b

File tree

2 files changed

+157
-1
lines changed

2 files changed

+157
-1
lines changed

redisvl/index/index.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,38 @@
3232
logger = get_logger(__name__)
3333

3434

35+
def _handle_dialect_3(result: Dict[str, Any]) -> Dict[str, Any]:
36+
"""
37+
Handle dialect 3 responses by converting JSON-encoded list values to strings.
38+
39+
Each JSON-encoded string in the result that is a list will be converted:
40+
- If the list has one item, it is unpacked.
41+
- If the list has multiple items, they are joined into a single comma-separated string.
42+
43+
Args:
44+
result (Dict[str, Any]): The dictionary containing the results to process.
45+
46+
Returns:
47+
Dict[str, Any]: The processed dictionary with updated values.
48+
"""
49+
for field, value in result.items():
50+
if isinstance(value, str):
51+
try:
52+
parsed_value = json.loads(value)
53+
except json.JSONDecodeError:
54+
continue # Skip processing if value is not valid JSON
55+
56+
if isinstance(parsed_value, list):
57+
# Use a single value if the list contains only one item, else join all items.
58+
result[field] = (
59+
parsed_value[0]
60+
if len(parsed_value) == 1
61+
else ", ".join(map(str, parsed_value))
62+
)
63+
64+
return result
65+
66+
3567
def process_results(
3668
results: "Result", query: BaseQuery, storage_type: StorageType
3769
) -> List[Dict[str, Any]]:
@@ -81,7 +113,13 @@ def _process(doc: "Document") -> Dict[str, Any]:
81113

82114
return doc_dict
83115

84-
return [_process(doc) for doc in results.docs]
116+
processed_results = [_process(doc) for doc in results.docs]
117+
118+
# Handle dialect 3 responses
119+
if query._dialect == 3:
120+
processed_results = [_handle_dialect_3(result) for result in processed_results]
121+
122+
return processed_results
85123

86124

87125
def check_modules_present():

tests/integration/test_dialects.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import json
2+
3+
import pytest
4+
from redis import Redis
5+
from redis.commands.search.query import Query
6+
7+
from redisvl.index import SearchIndex
8+
from redisvl.query import FilterQuery, VectorQuery
9+
from redisvl.query.filter import Tag
10+
from redisvl.schema.schema import IndexSchema
11+
12+
13+
@pytest.fixture
def sample_data():
    """Return two product documents used as test data."""
    headphones = {
        "name": "Noise-cancelling Bluetooth headphones",
        "description": "Wireless Bluetooth headphones with noise-cancelling technology",
        "connection": {"wireless": True, "type": "Bluetooth"},
        "price": 99.98,
        "stock": 25,
        "colors": ["black", "silver"],
        "embedding": [0.87, -0.15, 0.55, 0.03],
        "embeddings": [[0.56, -0.34, 0.69, 0.02], [0.94, -0.23, 0.45, 0.19]],
    }
    earbuds = {
        "name": "Wireless earbuds",
        "description": "Wireless Bluetooth in-ear headphones",
        "connection": {"wireless": True, "type": "Bluetooth"},
        "price": 64.99,
        "stock": 17,
        "colors": ["red", "black", "white"],
        "embedding": [-0.7, -0.51, 0.88, 0.14],
        "embeddings": [[0.54, -0.14, 0.79, 0.92], [0.94, -0.93, 0.45, 0.16]],
    }
    return [headphones, earbuds]
37+
38+
39+
@pytest.fixture
def schema_dict():
    """Return the schema dictionary for the JSON-stored "products" index.

    The schema declares text, tag, numeric and vector fields; some of them
    use explicit JSON paths.
    """
    index_settings = {"name": "products", "prefix": "product", "storage_type": "json"}

    # Single vector stored directly under the "embedding" key.
    flat_vector_field = {
        "name": "embedding",
        "type": "vector",
        "attrs": {"dims": 4, "algorithm": "flat", "distance_metric": "cosine"},
    }
    # Vector field whose path selects every entry of the "embeddings" array.
    hnsw_vector_field = {
        "name": "embeddings",
        "path": "$.embeddings[*]",
        "type": "vector",
        "attrs": {"dims": 4, "algorithm": "hnsw", "distance_metric": "l2"},
    }

    field_definitions = [
        {"name": "name", "type": "text"},
        {"name": "description", "type": "text"},
        {"name": "connection_type", "path": "$.connection.type", "type": "tag"},
        {"name": "price", "type": "numeric"},
        {"name": "stock", "type": "numeric"},
        {"name": "color", "path": "$.colors.*", "type": "tag"},
        flat_vector_field,
        hnsw_vector_field,
    ]

    return {"index": index_settings, "fields": field_definitions}
63+
64+
65+
@pytest.fixture
def index(sample_data, redis_url, schema_dict):
    """Yield a SearchIndex loaded with the sample data, then delete it.

    The index is built from ``schema_dict`` with a Redis client created from
    ``redis_url``. The code after the ``yield`` runs as fixture teardown.
    """
    # Arguments are evaluated left to right: schema first, then the client.
    search_index = SearchIndex(
        IndexSchema.from_dict(schema_dict),
        Redis.from_url(redis_url),
    )
    search_index.create(overwrite=True, drop=True)
    search_index.load(sample_data)

    yield search_index

    search_index.delete(drop=True)
74+
75+
76+
def _assert_dialect_3_results_flattened(results):
    """Assert results are non-empty and no returned field is list-wrapped."""
    assert len(results) > 0
    for result in results:
        assert not isinstance(result["name"], list)
        assert not isinstance(result["description"], list)
        # price must come back neither as a list nor as a string
        assert not isinstance(result["price"], (list, str))


def test_dialect_3_json(index, sample_data):
    """Dialect 3 vector and filter queries must return unwrapped field values."""
    # --- VectorQuery with dialect 3 ---
    vector_query = VectorQuery(
        vector=[0.23, 0.12, -0.03, 0.98],
        vector_field_name="embedding",
        return_fields=["name", "description", "price"],
        dialect=3,
    )
    vector_results = index.query(vector_query)

    print("VectorQuery Results:")
    print(vector_results)

    _assert_dialect_3_results_flattened(vector_results)

    # --- FilterQuery with dialect 3 ---
    filter_query = FilterQuery(
        filter_expression=Tag("color") == "black",
        return_fields=["name", "description", "price"],
        dialect=3,
    )
    filter_results = index.query(filter_query)

    print("FilterQuery Results:")
    print(filter_results)

    _assert_dialect_3_results_flattened(filter_results)

0 commit comments

Comments
 (0)