Skip to content

Commit 6362204

Browse files
committed
Add missing tests for unstack()
1 parent a04ff8f commit 6362204

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

array_api_tests/test_manipulation_functions.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,22 @@ def test_tile(x, data):
426426
def test_unstack(x, data):
427427
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
428428
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
429-
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
430-
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
431-
# TODO: shapes and values testing
429+
out = xp.unstack(x, **kw)
430+
431+
assert isinstance(out, tuple)
432+
assert len(out) == x.shape[axis]
433+
expected_shape = list(x.shape)
434+
expected_shape.pop(axis)
435+
expected_shape = tuple(expected_shape)
436+
for i in range(x.shape[axis]):
437+
arr = out[i]
438+
ph.assert_result_shape("unstack", in_shapes=[x.shape],
439+
out_shape=arr.shape, expected=expected_shape,
440+
kw=kw, repr_name=f"out[{i}].shape")
441+
442+
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
443+
repr_name=f"out[{i}].dtype")
444+
445+
idx = [slice(None)] * x.ndim
446+
idx[axis] = i
447+
ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")

0 commit comments

Comments
 (0)