Skip to content

Commit 6463307

Browse files
committed
Use xps.real_dtypes() in tests for comparison functions
1 parent 4a1c0a1 commit 6463307

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@given(
1818
x=xps.arrays(
19-
dtype=xps.numeric_dtypes(),
19+
dtype=xps.real_dtypes(),
2020
shape=hh.shapes(min_dims=1, min_side=1),
2121
elements={"allow_nan": False},
2222
),
@@ -53,7 +53,7 @@ def test_argmax(x, data):
5353

5454
@given(
5555
x=xps.arrays(
56-
dtype=xps.numeric_dtypes(),
56+
dtype=xps.real_dtypes(),
5757
shape=hh.shapes(min_dims=1, min_side=1),
5858
elements={"allow_nan": False},
5959
),

array_api_tests/test_sorting_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def assert_scalar_in_set(
3434
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
3535
@given(
3636
x=xps.arrays(
37-
dtype=xps.scalar_dtypes(),
37+
dtype=xps.real_dtypes(),
3838
shape=hh.shapes(min_dims=1, min_side=1),
3939
elements={"allow_nan": False},
4040
),
@@ -94,7 +94,7 @@ def test_argsort(x, data):
9494
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
9595
@given(
9696
x=xps.arrays(
97-
dtype=xps.scalar_dtypes(),
97+
dtype=xps.real_dtypes(),
9898
shape=hh.shapes(min_dims=1, min_side=1),
9999
elements={"allow_nan": False},
100100
),

array_api_tests/test_statistical_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2727

2828
@given(
2929
x=xps.arrays(
30-
dtype=xps.numeric_dtypes(),
30+
dtype=xps.real_dtypes(),
3131
shape=hh.shapes(min_side=1),
3232
elements={"allow_nan": False},
3333
),
@@ -79,7 +79,7 @@ def test_mean(x, data):
7979

8080
@given(
8181
x=xps.arrays(
82-
dtype=xps.numeric_dtypes(),
82+
dtype=xps.real_dtypes(),
8383
shape=hh.shapes(min_side=1),
8484
elements={"allow_nan": False},
8585
),

0 commit comments

Comments
 (0)