|
7 | 7 | from array_api_extra._lib import Backend
|
8 | 8 | from array_api_extra._lib._testing import xp_assert_equal
|
9 | 9 | from array_api_extra._lib._utils._compat import device as get_device
|
10 |
| -from array_api_extra._lib._utils._helpers import asarrays, in1d, ndindex |
| 10 | +from array_api_extra._lib._utils._helpers import asarrays, eager_shape, in1d, ndindex |
11 | 11 | from array_api_extra._lib._utils._typing import Array, Device, DType
|
12 | 12 | from array_api_extra.testing import lazy_xp_function
|
13 | 13 |
|
@@ -156,3 +156,20 @@ def test_numpy_generics(self, dtype: DType):
|
156 | 156 | )
|
157 | 157 | def test_ndindex(shape: tuple[int, ...]):
|
158 | 158 | assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape))
|
| 159 | + |
| 160 | + |
| 161 | +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") |
| 162 | +def test_eager_shape(xp: ModuleType, library: Backend): |
| 163 | + a = xp.asarray([1, 2, 3]) |
| 164 | + # Lazy arrays, like Dask, have an eager shape until you slice them with |
| 165 | + # a lazy boolean mask |
| 166 | + assert eager_shape(a) == a.shape == (3,) |
| 167 | + |
| 168 | + b = a[a > 2] |
| 169 | + if library is Backend.DASK: |
| 170 | + with pytest.raises(TypeError, match="Unsupported lazy shape"): |
| 171 | + _ = eager_shape(b) |
| 172 | + # FIXME can't test use case for None in the shape until we add support for |
| 173 | + # other lazy backends |
| 174 | + else: |
| 175 | + assert eager_shape(b) == b.shape == (1,) |
0 commit comments