From e46e9789500407a7f9b6d129312fb6db360214c9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 24 Sep 2024 15:12:58 -0600 Subject: [PATCH] Add array and axis testing to repeat() Still need to add values testing. --- .../test_manipulation_functions.py | 53 ++++++++++++++----- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 28b54802..1566b768 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -287,6 +287,45 @@ def test_permute_dims(x, axes): out_indices=permuted_indices) +@pytest.mark.min_version("2023.12") +@given( + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)), + kw=hh.kwargs( + axis=st.none() | shared_shapes(min_dims=1).flatmap( + lambda s: st.integers(-len(s), len(s) - 1) + ) + ), + data=st.data(), +) +def test_repeat(x, kw, data): + shape = x.shape + axis = kw.get("axis", None) + dim = math.prod(shape) if axis is None else shape[axis] + repeat_strat = st.integers(1, 4) + repeats = data.draw(repeat_strat + | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat, + shape=st.sampled_from([(1,), (dim,)])), + label="repeats") + if isinstance(repeats, int): + n_repitions = dim*repeats + else: + if repeats.shape == (1,): + n_repitions = dim*repeats[0] + else: + n_repitions = int(xp.sum(repeats)) + + out = xp.repeat(x, repeats, **kw) + ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) + if axis is None: + expected_shape = (n_repitions,) + else: + expected_shape = list(shape) + expected_shape[axis] = n_repitions + expected_shape = tuple(expected_shape) + ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) + # TODO: values testing + + @st.composite def reshape_shapes(draw, shape): size = 1 if len(shape) == 0 else math.prod(shape) @@ -298,20 +337,6 @@ def reshape_shapes(draw, shape): return tuple(rshape) -@pytest.mark.min_version("2023.12") -@given( - x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), - repeats=st.integers(1, 4), -) -def test_repeat(x, repeats): - # TODO: test array repeats and non-None axis, adjust shape and value testing accordingly - out = xp.repeat(x, repeats) - ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) - expected_shape = (math.prod(x.shape) * repeats,) - ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) - # TODO: values testing - - @pytest.mark.unvectorized @pytest.mark.skip("flaky") # TODO: fix! @given(