Skip to content

Commit bee123a

Browse files
committed
Updates to op/elwise tests
* Use `xps` dtype strategies where possible for better repr and dtype filtering * Helper for asserting broadcasted shapes * Helper `sh.iter_indices()` to wrap `ndindex` equivalent * Update `test_equal` with `sh.iter_indices()`
1 parent 693c29b commit bee123a

File tree

2 files changed

+112
-80
lines changed

2 files changed

+112
-80
lines changed

array_api_tests/shape_helpers.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from itertools import product
33
from typing import Iterator, List, Optional, Tuple, Union
44

5+
from ndindex import iter_indices as _iter_indices
6+
57
from .typing import Scalar, Shape
68

79
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
@@ -18,12 +20,14 @@ def normalise_axis(
1820

1921

2022
def ndindex(shape):
21-
"""Iterator of n-D indices to an array
23+
# TODO: remove
24+
return (indices[0] for indices in iter_indices(shape))
25+
2226

23-
Yields tuples of integers to index every element of an array of shape
24-
`shape`. Same as np.ndindex().
25-
"""
26-
return product(*[range(i) for i in shape])
27+
def iter_indices(*shapes, skip_axes=()):
28+
"""Wrapper for ndindex.iter_indices()"""
29+
gen = _iter_indices(*shapes, skip_axes=skip_axes)
30+
return ([i.raw for i in indices] for indices in gen)
2731

2832

2933
def axis_ndindex(

0 commit comments

Comments
 (0)