Skip to content

Commit ca55dc9

Browse files
authored
Merge pull request #162 from crusaderky/expand_dims
TST: fix failures in `expand_dims` test
2 parents 308fc1f + 64378b7 commit ca55dc9

File tree

1 file changed

+6
-18
lines changed

1 file changed

+6
-18
lines changed

tests/test_funcs.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import math
32
import warnings
43
from types import ModuleType
@@ -24,7 +23,7 @@
2423
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2524
from array_api_extra._lib._utils._compat import device as get_device
2625
from array_api_extra._lib._utils._helpers import eager_shape, ndindex
27-
from array_api_extra._lib._utils._typing import Array, Device
26+
from array_api_extra._lib._utils._typing import Device
2827
from array_api_extra.testing import lazy_xp_function
2928

3029
# some xp backends are untyped
@@ -291,22 +290,12 @@ def test_xp(self, xp: ModuleType):
291290

292291
class TestExpandDims:
293292
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
294-
@pytest.mark.xfail_xp_backend(Backend.DASK, reason="tuple index out of range")
295-
@pytest.mark.xfail_xp_backend(Backend.TORCH, reason="tuple index out of range")
296-
def test_functionality(self, xp: ModuleType):
297-
def _squeeze_all(b: Array) -> Array:
298-
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
299-
for axis in range(b.ndim):
300-
with contextlib.suppress(ValueError):
301-
b = xp.squeeze(b, axis=axis)
302-
return b
303-
304-
s = (2, 3, 4, 5)
305-
a = xp.empty(s)
293+
def test_single_axis(self, xp: ModuleType):
294+
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
295+
a = xp.empty((2, 3, 4, 5))
306296
for axis in range(-5, 4):
307297
b = expand_dims(a, axis=axis)
308-
assert b.shape[axis] == 1
309-
assert _squeeze_all(b).shape == s
298+
xp_assert_equal(b, xp.expand_dims(a, axis=axis))
310299

311300
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
312301
def test_axis_tuple(self, xp: ModuleType):
@@ -317,8 +306,7 @@ def test_axis_tuple(self, xp: ModuleType):
317306
assert expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)
318307

319308
def test_axis_out_of_range(self, xp: ModuleType):
320-
s = (2, 3, 4, 5)
321-
a = xp.empty(s)
309+
a = xp.empty((2, 3, 4, 5))
322310
with pytest.raises(IndexError, match="out of bounds"):
323311
_ = expand_dims(a, axis=-6)
324312
with pytest.raises(IndexError, match="out of bounds"):

0 commit comments

Comments
 (0)