Skip to content

Commit 458be76

Browse files
Add optional sorting option to RedisVL queries (#148)
This PR adds the ability to sort results by a hash field when the field name is passed as the `sort_by` parameter to filter, vector, and range queries.
1 parent 5e845f2 commit 458be76

File tree

3 files changed

+116
-3
lines changed

3 files changed

+116
-3
lines changed

redisvl/query/query.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
return_fields: Optional[List[str]] = None,
137137
num_results: int = 10,
138138
dialect: int = 2,
139+
sort_by: Optional[str] = None,
139140
params: Optional[Dict[str, Any]] = None,
140141
):
141142
"""A query for running a filtered search with a filter expression.
@@ -146,6 +147,8 @@ def __init__(
146147
return_fields (Optional[List[str]], optional): The fields to return.
147148
num_results (Optional[int], optional): The number of results to
148149
return. Defaults to 10.
150+
sort_by (Optional[str]): The field to order the results by. Defaults
151+
to None, in which case no explicit sort order is applied.
149152
params (Optional[Dict[str, Any]], optional): The parameters for the
150153
query. Defaults to None.
151154
@@ -164,6 +167,7 @@ def __init__(
164167
"""
165168
super().__init__(return_fields, num_results, dialect)
166169
self.set_filter(filter_expression)
170+
self._sort_by = sort_by
167171
self._params = params or {}
168172

169173
@property
@@ -180,6 +184,8 @@ def query(self) -> Query:
180184
.paging(self._first, self._limit)
181185
.dialect(self._dialect)
182186
)
187+
if self._sort_by:
188+
query = query.sort_by(self._sort_by)
183189
return query
184190

185191

@@ -201,12 +207,14 @@ def __init__(
201207
num_results: int = 10,
202208
return_score: bool = True,
203209
dialect: int = 2,
210+
sort_by: Optional[str] = None,
204211
):
205212
super().__init__(return_fields, num_results, dialect)
206213
self.set_filter(filter_expression)
207214
self._vector = vector
208215
self._field = vector_field_name
209216
self._dtype = dtype.lower()
217+
self._sort_by = sort_by
210218

211219
if return_score:
212220
self._return_fields.append(self.DISTANCE_ID)
@@ -223,6 +231,7 @@ def __init__(
223231
num_results: int = 10,
224232
return_score: bool = True,
225233
dialect: int = 2,
234+
sort_by: Optional[str] = None,
226235
):
227236
"""A query for running a vector search along with an optional filter
228237
expression.
@@ -243,6 +252,8 @@ def __init__(
243252
distance. Defaults to True.
244253
dialect (int, optional): The RediSearch query dialect.
245254
Defaults to 2.
255+
sort_by (Optional[str]): The field to order the results by. Defaults
256+
to None, in which case results are ordered by vector distance.
246257
247258
Raises:
248259
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -259,6 +270,7 @@ def __init__(
259270
num_results,
260271
return_score,
261272
dialect,
273+
sort_by,
262274
)
263275

264276
@property
@@ -272,10 +284,13 @@ def query(self) -> Query:
272284
query = (
273285
Query(base_query)
274286
.return_fields(*self._return_fields)
275-
.sort_by(self.DISTANCE_ID)
276287
.paging(self._first, self._limit)
277288
.dialect(self._dialect)
278289
)
290+
if self._sort_by:
291+
query = query.sort_by(self._sort_by)
292+
else:
293+
query = query.sort_by(self.DISTANCE_ID)
279294
return query
280295

281296
@property
@@ -307,6 +322,7 @@ def __init__(
307322
num_results: int = 10,
308323
return_score: bool = True,
309324
dialect: int = 2,
325+
sort_by: Optional[str] = None,
310326
):
311327
"""A query for running a filtered vector search based on semantic
312328
distance threshold.
@@ -330,7 +346,8 @@ def __init__(
330346
distance. Defaults to True.
331347
dialect (int, optional): The RediSearch query dialect.
332348
Defaults to 2.
333-
349+
sort_by (Optional[str]): The field to order the results by. Defaults
350+
to None, in which case results are ordered by vector distance.
334351
Raises:
335352
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
336353
@@ -347,6 +364,7 @@ def __init__(
347364
num_results,
348365
return_score,
349366
dialect,
367+
sort_by,
350368
)
351369
self.set_distance_threshold(distance_threshold)
352370

@@ -390,10 +408,13 @@ def query(self) -> Query:
390408
query = (
391409
Query(base_query)
392410
.return_fields(*self._return_fields)
393-
.sort_by(self.DISTANCE_ID)
394411
.paging(self._first, self._limit)
395412
.dialect(self._dialect)
396413
)
414+
if self._sort_by:
415+
query = query.sort_by(self._sort_by)
416+
else:
417+
query = query.sort_by(self.DISTANCE_ID)
397418
return query
398419

399420
@property

tests/integration/test_query.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ def vector_query():
1818
)
1919

2020

21+
@pytest.fixture
22+
def sorted_vector_query():
23+
return VectorQuery(
24+
vector=[0.1, 0.1, 0.5],
25+
vector_field_name="user_embedding",
26+
return_fields=["user", "credit_score", "age", "job", "location"],
27+
sort_by="age",
28+
)
29+
30+
2131
@pytest.fixture
2232
def filter_query():
2333
return FilterQuery(
@@ -26,6 +36,15 @@ def filter_query():
2636
)
2737

2838

39+
@pytest.fixture
40+
def sorted_filter_query():
41+
return FilterQuery(
42+
return_fields=["user", "credit_score", "age", "job", "location"],
43+
filter_expression=Tag("credit_score") == "high",
44+
sort_by="age",
45+
)
46+
47+
2948
@pytest.fixture
3049
def range_query():
3150
return RangeQuery(
@@ -36,6 +55,17 @@ def range_query():
3655
)
3756

3857

58+
@pytest.fixture
59+
def sorted_range_query():
60+
return RangeQuery(
61+
vector=[0.1, 0.1, 0.5],
62+
vector_field_name="user_embedding",
63+
return_fields=["user", "credit_score", "age", "job", "location"],
64+
distance_threshold=0.2,
65+
sort_by="age",
66+
)
67+
68+
3969
@pytest.fixture
4070
def index(sample_data, redis_url):
4171
# construct a search index from the schema
@@ -160,6 +190,7 @@ def search(
160190
age_range=None,
161191
location=None,
162192
distance_threshold=0.2,
193+
sort=False,
163194
):
164195
"""Utility function to test filters."""
165196

@@ -199,6 +230,21 @@ def search(
199230
else:
200231
assert len(results.docs) == expected_count
201232

233+
# check results are in sorted order
234+
if sort:
235+
if isinstance(query, RangeQuery):
236+
assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100]
237+
else:
238+
assert [int(doc.age) for doc in results.docs] == [
239+
12,
240+
14,
241+
15,
242+
18,
243+
35,
244+
94,
245+
100,
246+
]
247+
202248

203249
@pytest.fixture(
204250
params=["vector_query", "filter_query", "range_query"],
@@ -339,3 +385,18 @@ def test_paginate_range_query(index, range_query):
339385
assert len(all_results) == expected_count
340386
assert i == expected_iterations
341387
assert all(float(item["vector_distance"]) <= 0.2 for item in all_results)
388+
389+
390+
def test_sort_filter_query(index, sorted_filter_query):
391+
t = Text("job") % ""
392+
search(sorted_filter_query, index, t, 7, sort=True)
393+
394+
395+
def test_sort_vector_query(index, sorted_vector_query):
396+
t = Text("job") % ""
397+
search(sorted_vector_query, index, t, 7, sort=True)
398+
399+
400+
def test_sort_range_query(index, sorted_range_query):
401+
t = Text("job") % ""
402+
search(sorted_range_query, index, t, 7, sort=True)

tests/unit/test_query_types.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_filter_query():
4545
assert isinstance(filter_query.params, dict)
4646
assert filter_query.params == {}
4747
assert filter_query._dialect == 2
48+
assert filter_query._sort_by == None
4849

4950
# Test set_filter functionality
5051
new_filter_expression = Tag("category") == "Sportswear"
@@ -57,6 +58,12 @@ def test_filter_query():
5758
assert filter_query._limit == 7
5859
assert filter_query._num_results == 10
5960

61+
# Test sort_by functionality
62+
filter_query = FilterQuery(
63+
filter_expression, return_fields, num_results=10, sort_by="price"
64+
)
65+
assert filter_query._sort_by == "price"
66+
6067

6168
def test_vector_query():
6269
# Create a vector query
@@ -73,6 +80,7 @@ def test_vector_query():
7380
assert isinstance(vector_query.params, dict)
7481
assert vector_query.params != {}
7582
assert vector_query._dialect == 3
83+
assert vector_query._sort_by == None
7684

7785
# Test set_filter functionality
7886
new_filter_expression = Tag("category") == "Sportswear"
@@ -85,6 +93,17 @@ def test_vector_query():
8593
assert vector_query._limit == 7
8694
assert vector_query._num_results == 10
8795

96+
# Test sort_by functionality
97+
vector_query = VectorQuery(
98+
sample_vector,
99+
"vector_field",
100+
["field1", "field2"],
101+
dialect=3,
102+
num_results=10,
103+
sort_by="field2",
104+
)
105+
assert vector_query._sort_by == "field2"
106+
88107

89108
def test_range_query():
90109
# Create a filter expression
@@ -104,6 +123,7 @@ def test_range_query():
104123
assert isinstance(range_query.query, Query)
105124
assert isinstance(range_query.params, dict)
106125
assert range_query.params != {}
126+
assert range_query._sort_by == None
107127

108128
# Test set_filter functionality
109129
new_filter_expression = Tag("category") == "Outdoor"
@@ -115,3 +135,14 @@ def test_range_query():
115135
assert range_query._first == 5
116136
assert range_query._limit == 7
117137
assert range_query._num_results == 10
138+
139+
# Test sort_by functionality
140+
range_query = RangeQuery(
141+
sample_vector,
142+
"vector_field",
143+
["field1"],
144+
filter_expression,
145+
num_results=10,
146+
sort_by="field1",
147+
)
148+
assert range_query._sort_by == "field1"

0 commit comments

Comments
 (0)