Skip to content

Commit de35948

Browse files
committed
Swap {in/out}_shape params in ph.assert_keepdimable_shape()
Also documents it
1 parent d8c25e3 commit de35948

File tree

4 files changed

+27
-12
lines changed

4 files changed

+27
-12
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,28 @@ def assert_result_shape(
264264

265265
def assert_keepdimable_shape(
266266
func_name: str,
267-
out_shape: Shape,
268267
in_shape: Shape,
268+
out_shape: Shape,
269269
axes: Tuple[int, ...],
270270
keepdims: bool,
271271
/,
272272
**kw,
273273
):
274+
"""
275+
Assert the output shape from a keepdimable function is as expected, e.g.
276+
277+
>>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
278+
>>> out1 = xp.max(x, keepdims=False)
279+
>>> out2 = xp.max(x, keepdims=True)
280+
>>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False)
281+
>>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True)
282+
283+
is equivalent to
284+
285+
>>> assert out1.shape == ()
286+
>>> assert out2.shape == (1, 1)
287+
288+
"""
274289
if keepdims:
275290
shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
276291
else:

array_api_tests/test_searching_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_argmax(x, data):
3434
ph.assert_default_index("argmax", out.dtype)
3535
axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
3636
ph.assert_keepdimable_shape(
37-
"argmax", out.shape, x.shape, axes, kw.get("keepdims", False), **kw
37+
"argmax", x.shape, out.shape, axes, kw.get("keepdims", False), **kw
3838
)
3939
scalar_type = dh.get_scalar_type(x.dtype)
4040
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
@@ -69,7 +69,7 @@ def test_argmin(x, data):
6969
ph.assert_default_index("argmin", out.dtype)
7070
axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
7171
ph.assert_keepdimable_shape(
72-
"argmin", out.shape, x.shape, axes, kw.get("keepdims", False), **kw
72+
"argmin", x.shape, out.shape, axes, kw.get("keepdims", False), **kw
7373
)
7474
scalar_type = dh.get_scalar_type(x.dtype)
7575
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):

array_api_tests/test_statistical_functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_max(x, data):
3838
ph.assert_dtype("max", x.dtype, out.dtype)
3939
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
4040
ph.assert_keepdimable_shape(
41-
"max", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
41+
"max", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
4242
)
4343
scalar_type = dh.get_scalar_type(out.dtype)
4444
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
@@ -67,7 +67,7 @@ def test_mean(x, data):
6767
ph.assert_dtype("mean", x.dtype, out.dtype)
6868
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
6969
ph.assert_keepdimable_shape(
70-
"mean", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
70+
"mean", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
7171
)
7272
# Values testing mean is too finicky
7373

@@ -88,7 +88,7 @@ def test_min(x, data):
8888
ph.assert_dtype("min", x.dtype, out.dtype)
8989
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
9090
ph.assert_keepdimable_shape(
91-
"min", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
91+
"min", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
9292
)
9393
scalar_type = dh.get_scalar_type(out.dtype)
9494
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
@@ -147,7 +147,7 @@ def test_prod(x, data):
147147
ph.assert_dtype("prod", x.dtype, out.dtype, _dtype)
148148
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
149149
ph.assert_keepdimable_shape(
150-
"prod", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
150+
"prod", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
151151
)
152152
scalar_type = dh.get_scalar_type(out.dtype)
153153
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
@@ -194,7 +194,7 @@ def test_std(x, data):
194194

195195
ph.assert_dtype("std", x.dtype, out.dtype)
196196
ph.assert_keepdimable_shape(
197-
"std", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
197+
"std", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
198198
)
199199
# We can't easily test the result(s) as standard deviation methods vary a lot
200200

@@ -245,7 +245,7 @@ def test_sum(x, data):
245245
ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
246246
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
247247
ph.assert_keepdimable_shape(
248-
"sum", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
248+
"sum", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
249249
)
250250
scalar_type = dh.get_scalar_type(out.dtype)
251251
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
@@ -292,6 +292,6 @@ def test_var(x, data):
292292

293293
ph.assert_dtype("var", x.dtype, out.dtype)
294294
ph.assert_keepdimable_shape(
295-
"var", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
295+
"var", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
296296
)
297297
# We can't easily test the result(s) as variance methods vary a lot

array_api_tests/test_utility_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_all(x, data):
2424
ph.assert_dtype("all", x.dtype, out.dtype, xp.bool)
2525
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
2626
ph.assert_keepdimable_shape(
27-
"all", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
27+
"all", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
2828
)
2929
scalar_type = dh.get_scalar_type(x.dtype)
3030
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
@@ -49,7 +49,7 @@ def test_any(x, data):
4949
ph.assert_dtype("any", x.dtype, out.dtype, xp.bool)
5050
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
5151
ph.assert_keepdimable_shape(
52-
"any", out.shape, x.shape, _axes, kw.get("keepdims", False), **kw
52+
"any", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
5353
)
5454
scalar_type = dh.get_scalar_type(x.dtype)
5555
for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):

0 commit comments

Comments
 (0)