Skip to content

Commit 8de7cdf

Browse files
committed
Always make axis=None for 0d arrays in argmin() and argmax()
1 parent cb2e7d0 commit 8de7cdf

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,10 @@
2121
data=st.data(),
2222
)
2323
def test_argmax(x, data):
24-
kw = data.draw(
25-
hh.kwargs(
26-
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
27-
keepdims=st.booleans(),
28-
),
29-
label="kw",
30-
)
24+
axis_strat = st.none()
25+
if x.ndim > 0:
26+
axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0))
27+
kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw")
3128

3229
out = xp.argmax(x, **kw)
3330

@@ -56,13 +53,10 @@ def test_argmax(x, data):
5653
data=st.data(),
5754
)
5855
def test_argmin(x, data):
59-
kw = data.draw(
60-
hh.kwargs(
61-
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
62-
keepdims=st.booleans(),
63-
),
64-
label="kw",
65-
)
56+
axis_strat = st.none()
57+
if x.ndim > 0:
58+
axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0))
59+
kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw")
6660

6761
out = xp.argmin(x, **kw)
6862

0 commit comments

Comments
 (0)