@@ -426,6 +426,22 @@ def test_tile(x, data):
426
426
def test_unstack (x , data ):
427
427
axis = data .draw (st .integers (min_value = - x .ndim , max_value = x .ndim - 1 ), label = "axis" )
428
428
kw = data .draw (hh .specified_kwargs (("axis" , axis , 0 )), label = "kw" )
429
- out = xp .asarray (xp .unstack (x , ** kw ), dtype = x .dtype )
430
- ph .assert_dtype ("unstack" , in_dtype = x .dtype , out_dtype = out .dtype )
431
- # TODO: shapes and values testing
429
+ out = xp .unstack (x , ** kw )
430
+
431
+ assert isinstance (out , tuple )
432
+ assert len (out ) == x .shape [axis ]
433
+ expected_shape = list (x .shape )
434
+ expected_shape .pop (axis )
435
+ expected_shape = tuple (expected_shape )
436
+ for i in range (x .shape [axis ]):
437
+ arr = out [i ]
438
+ ph .assert_result_shape ("unstack" , in_shapes = [x .shape ],
439
+ out_shape = arr .shape , expected = expected_shape ,
440
+ kw = kw , repr_name = f"out[{ i } ].shape" )
441
+
442
+ ph .assert_dtype ("unstack" , in_dtype = x .dtype , out_dtype = arr .dtype ,
443
+ repr_name = f"out[{ i } ].dtype" )
444
+
445
+ idx = [slice (None )] * x .ndim
446
+ idx [axis ] = i
447
+ ph .assert_array_elements ("unstack" , out = arr , expected = x [tuple (idx )], kw = kw , out_repr = f"out[{ i } ]" )
0 commit comments