Skip to content

Commit 6a51585

Browse files
committed
Skip sh.iter_indices() generation for 0-sided shapes
Also updates `test_logical_and`
1 parent c22efdf commit 6a51585

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

array_api_tests/shape_helpers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,21 @@ def normalise_axis(
7676
return axes
7777

7878

79-
def ndindex(shape):
80-
"""Yield every index of shape"""
79+
def ndindex(shape: Shape) -> Iterator[Index]:
80+
"""Yield every index of a shape"""
8181
return (indices[0] for indices in iter_indices(shape))
8282

8383

84-
def iter_indices(*shapes, skip_axes=()):
84+
def iter_indices(
85+
*shapes: Shape, skip_axes: Tuple[int, ...] = ()
86+
) -> Iterator[Tuple[Index, ...]]:
8587
"""Wrapper for ndindex.iter_indices()"""
86-
gen = _iter_indices(*shapes, skip_axes=skip_axes)
87-
return ([i.raw for i in indices] for indices in gen)
88+
# Prevent iterations if any shape has 0-sides
89+
for shape in shapes:
90+
if 0 in shape:
91+
return
92+
for indices in _iter_indices(*shapes, skip_axes=skip_axes):
93+
yield tuple(i.raw for i in indices) # type: ignore
8894

8995

9096
def axis_ndindex(

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,13 +1221,19 @@ def test_logaddexp(x1, x2):
12211221
def test_logical_and(x1, x2):
12221222
out = ah.logical_and(x1, x2)
12231223
ph.assert_dtype("logical_and", (x1.dtype, x2.dtype), out.dtype)
1224-
# See the comments in test_equal
1225-
shape = sh.broadcast_shapes(x1.shape, x2.shape)
1226-
ph.assert_shape("logical_and", out.shape, shape)
1227-
_x1 = xp.broadcast_to(x1, shape)
1228-
_x2 = xp.broadcast_to(x2, shape)
1229-
for idx in sh.ndindex(shape):
1230-
assert out[idx] == (bool(_x1[idx]) and bool(_x2[idx]))
1224+
ph.assert_result_shape("logical_and", (x1.shape, x2.shape), out.shape)
1225+
for l_idx, r_idx, o_idx in sh.iter_indices(x1.shape, x2.shape, out.shape):
1226+
scalar_l = bool(x1[l_idx])
1227+
scalar_r = bool(x2[r_idx])
1228+
expected = scalar_l and scalar_r
1229+
scalar_o = bool(out[o_idx])
1230+
f_l = sh.fmt_idx("x1", l_idx)
1231+
f_r = sh.fmt_idx("x2", r_idx)
1232+
f_o = sh.fmt_idx("out", o_idx)
1233+
assert scalar_o == expected, (
1234+
f"{f_o}={scalar_o}, but should be ({f_l} and {f_r})={expected} "
1235+
f"[logical_and()]\n{f_l}={scalar_l}, {f_r}={scalar_r}"
1236+
)
12311237

12321238

12331239
@given(xps.arrays(dtype=xp.bool, shape=hh.shapes()))

0 commit comments

Comments
 (0)