1
- import contextlib
2
1
import math
3
2
import warnings
4
3
from types import ModuleType
24
23
from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
25
24
from array_api_extra ._lib ._utils ._compat import device as get_device
26
25
from array_api_extra ._lib ._utils ._helpers import eager_shape , ndindex
27
- from array_api_extra ._lib ._utils ._typing import Array , Device
26
+ from array_api_extra ._lib ._utils ._typing import Device
28
27
from array_api_extra .testing import lazy_xp_function
29
28
30
29
# some xp backends are untyped
@@ -291,22 +290,12 @@ def test_xp(self, xp: ModuleType):
291
290
292
291
class TestExpandDims :
293
292
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
294
- @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "tuple index out of range" )
295
- @pytest .mark .xfail_xp_backend (Backend .TORCH , reason = "tuple index out of range" )
296
- def test_functionality (self , xp : ModuleType ):
297
- def _squeeze_all (b : Array ) -> Array :
298
- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
299
- for axis in range (b .ndim ):
300
- with contextlib .suppress (ValueError ):
301
- b = xp .squeeze (b , axis = axis )
302
- return b
303
-
304
- s = (2 , 3 , 4 , 5 )
305
- a = xp .empty (s )
293
+ def test_single_axis (self , xp : ModuleType ):
294
+ """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""
295
+ a = xp .empty ((2 , 3 , 4 , 5 ))
306
296
for axis in range (- 5 , 4 ):
307
297
b = expand_dims (a , axis = axis )
308
- assert b .shape [axis ] == 1
309
- assert _squeeze_all (b ).shape == s
298
+ xp_assert_equal (b , xp .expand_dims (a , axis = axis ))
310
299
311
300
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no expand_dims" )
312
301
def test_axis_tuple (self , xp : ModuleType ):
@@ -317,8 +306,7 @@ def test_axis_tuple(self, xp: ModuleType):
317
306
assert expand_dims (a , axis = (0 , - 3 , - 5 )).shape == (1 , 1 , 3 , 1 , 3 , 3 )
318
307
319
308
def test_axis_out_of_range (self , xp : ModuleType ):
320
- s = (2 , 3 , 4 , 5 )
321
- a = xp .empty (s )
309
+ a = xp .empty ((2 , 3 , 4 , 5 ))
322
310
with pytest .raises (IndexError , match = "out of bounds" ):
323
311
_ = expand_dims (a , axis = - 6 )
324
312
with pytest .raises (IndexError , match = "out of bounds" ):
0 commit comments