diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1566b768..b8a919c4 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -300,31 +300,53 @@ def test_permute_dims(x, axes): def test_repeat(x, kw, data): shape = x.shape axis = kw.get("axis", None) - dim = math.prod(shape) if axis is None else shape[axis] - repeat_strat = st.integers(1, 4) + size = math.prod(shape) if axis is None else shape[axis] + repeat_strat = st.integers(1, 10) repeats = data.draw(repeat_strat | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat, - shape=st.sampled_from([(1,), (dim,)])), + shape=st.sampled_from([(1,), (size,)])), label="repeats") if isinstance(repeats, int): - n_repitions = dim*repeats + n_repititions = size*repeats else: if repeats.shape == (1,): - n_repitions = dim*repeats[0] + n_repititions = size*int(repeats[0]) else: - n_repitions = int(xp.sum(repeats)) + n_repititions = int(xp.sum(repeats)) + + assume(n_repititions <= hh.SQRT_MAX_ARRAY_SIZE) out = xp.repeat(x, repeats, **kw) ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) if axis is None: - expected_shape = (n_repitions,) + expected_shape = (n_repititions,) else: expected_shape = list(shape) - expected_shape[axis] = n_repitions + expected_shape[axis] = n_repititions expected_shape = tuple(expected_shape) ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) - # TODO: values testing + # Test values + + if isinstance(repeats, int): + repeats_array = xp.full(size, repeats, dtype=xp.int32) + else: + repeats_array = repeats + + if kw.get("axis") is None: + x = xp.reshape(x, (-1,)) + axis = 0 + + for idx, in sh.iter_indices(x.shape, skip_axes=axis): + x_slice = x[idx] + out_slice = out[idx] + start = 0 + for i, count in enumerate(repeats_array): + end = start + count + ph.assert_array_elements("repeat", out=out_slice[start:end], + expected=xp.full((count,), x_slice[i], dtype=x.dtype), + kw=kw) + start = end @st.composite def reshape_shapes(draw, shape):