Skip to content

Commit 4a0d975

Browse files
committed
Fix the searchsorted test (and add a TODO)
1 parent 9aad419 commit 4a0d975

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

array_api_tests/dtype_helpers.py

+1
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
315315
)
316316
else:
317317
default_complex = None
318+
318319
if dtype_nbits[default_int] == 32:
319320
default_uint = getattr(xp, "uint32", None)
320321
else:

array_api_tests/test_searching_functions.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,14 @@ def test_where(shapes, dtypes, data):
173173
@given(data=st.data())
174174
def test_searchsorted(data):
175175
# TODO: test side="right"
176+
# TODO: Allow different dtypes for x1 and x2
176177
_x1 = data.draw(
177178
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
178179
label="_x1",
179180
)
180181
x1 = xp.asarray(_x1, dtype=dh.default_float)
181182
if data.draw(st.booleans(), label="use sorter?"):
182-
sorter = data.draw(
183-
st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
184-
label="sorter",
185-
)
183+
sorter = xp.argsort(x1)
186184
else:
187185
sorter = None
188186
x1 = xp.sort(x1)
@@ -202,4 +200,4 @@ def test_searchsorted(data):
202200
out_dtype=out.dtype,
203201
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
204202
)
205-
# TODO: shapes and values testing
203+
# TODO: shapes and values testing

0 commit comments

Comments
 (0)