From 866868f5ec3147782a07cd67352c9ad6c99f6453 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 14 Oct 2024 13:43:23 -0600 Subject: [PATCH 1/5] Use a better variable name --- array_api_tests/test_manipulation_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1566b768..72ecd855 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -300,17 +300,17 @@ def test_permute_dims(x, axes): def test_repeat(x, kw, data): shape = x.shape axis = kw.get("axis", None) - dim = math.prod(shape) if axis is None else shape[axis] + size = math.prod(shape) if axis is None else shape[axis] repeat_strat = st.integers(1, 4) repeats = data.draw(repeat_strat | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat, - shape=st.sampled_from([(1,), (dim,)])), + shape=st.sampled_from([(1,), (size,)])), label="repeats") if isinstance(repeats, int): - n_repitions = dim*repeats + n_repitions = size*repeats else: if repeats.shape == (1,): - n_repitions = dim*repeats[0] + n_repitions = size*repeats[0] else: n_repitions = int(xp.sum(repeats)) From 89a6addf999d715960648d1a2e110d77d2ea8b19 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 14 Oct 2024 14:58:35 -0600 Subject: [PATCH 2/5] Make sure a scalar value is a Python scalar --- array_api_tests/test_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 72ecd855..9b89b0b0 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -310,7 +310,7 @@ def test_repeat(x, kw, data): n_repitions = size*repeats else: if repeats.shape == (1,): - n_repitions = size*repeats[0] + n_repitions = size*int(repeats[0]) else: n_repitions = int(xp.sum(repeats)) From d218c3602df36ce865e5f3b381e332c9b8e45c66 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 14 Oct 2024 14:59:15 -0600 Subject: [PATCH 3/5] Fix spelling of a variable name --- array_api_tests/test_manipulation_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 9b89b0b0..dba82eea 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -307,20 +307,20 @@ def test_repeat(x, kw, data): shape=st.sampled_from([(1,), (size,)])), label="repeats") if isinstance(repeats, int): - n_repitions = size*repeats + n_repititions = size*repeats else: if repeats.shape == (1,): - n_repitions = size*int(repeats[0]) + n_repititions = size*int(repeats[0]) else: - n_repitions = int(xp.sum(repeats)) + n_repititions = int(xp.sum(repeats)) out = xp.repeat(x, repeats, **kw) ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) if axis is None: - expected_shape = (n_repitions,) + expected_shape = (n_repititions,) else: expected_shape = list(shape) - expected_shape[axis] = n_repitions + expected_shape[axis] = n_repititions expected_shape = tuple(expected_shape) ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) # TODO: values testing From c0c6ba95baba6b0562674ee8f58db54e954356e4 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 14 Oct 2024 14:59:25 -0600 Subject: [PATCH 4/5] Limit the repititions by the total array size in test_repeat --- array_api_tests/test_manipulation_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index dba82eea..1281f0ce 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -301,7 +301,7 @@ def test_repeat(x, kw, data): shape = x.shape axis = kw.get("axis", None) size = math.prod(shape) if axis is None else shape[axis] - repeat_strat = st.integers(1, 4) + repeat_strat = st.integers(1, 10) repeats = data.draw(repeat_strat | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat, shape=st.sampled_from([(1,), (size,)])), @@ -314,6 +314,8 @@ def test_repeat(x, kw, data): else: n_repititions = int(xp.sum(repeats)) + assume(n_repititions <= hh.SQRT_MAX_ARRAY_SIZE) + out = xp.repeat(x, repeats, **kw) ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype) if axis is None: From 89d4ac0c001eb883a11e0f8e55fc01e45a182d78 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 14 Oct 2024 14:59:48 -0600 Subject: [PATCH 5/5] Add values testing to test_repeat --- .../test_manipulation_functions.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index 1281f0ce..b8a919c4 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -325,8 +325,28 @@ def test_repeat(x, kw, data): expected_shape[axis] = n_repititions expected_shape = tuple(expected_shape) ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape) - # TODO: values testing + # Test values + + if isinstance(repeats, int): + repeats_array = xp.full(size, repeats, dtype=xp.int32) + else: + repeats_array = repeats + + if kw.get("axis") is None: + x = xp.reshape(x, (-1,)) + axis = 0 + + for idx, in sh.iter_indices(x.shape, skip_axes=axis): + x_slice = x[idx] + out_slice = out[idx] + start = 0 + for i, count in enumerate(repeats_array): + end = start + count + ph.assert_array_elements("repeat", out=out_slice[start:end], + expected=xp.full((count,), x_slice[i], dtype=x.dtype), + kw=kw) + start = end @st.composite def reshape_shapes(draw, shape):