Skip to content

Commit 8a50ebc

Browse files
committed
Add shape, dtype, and value testing for cumulative_sum
1 parent c8c9498 commit 8a50ebc

File tree

1 file changed

+67
-7
lines changed

1 file changed

+67
-7
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from hypothesis import assume, given
77
from hypothesis import strategies as st
8+
from ndindex import iter_indices
89

910
from . import _array_module as xp
1011
from . import dtype_helpers as dh
@@ -17,13 +18,72 @@
1718

1819

1920
@pytest.mark.min_version("2023.12")
20-
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1)))
21-
def test_cumulative_sum(x):
22-
# TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
23-
out = xp.cumulative_sum(x)
24-
# TODO: assert dtype
25-
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
26-
# TODO: assert values
21+
@given(
22+
x=hh.arrays(
23+
dtype=xps.numeric_dtypes(),
24+
shape=hh.shapes(min_dims=1)),
25+
data=st.data(),
26+
)
27+
def test_cumulative_sum(x, data):
28+
axes = st.integers(-x.ndim, x.ndim - 1)
29+
if x.ndim == 1:
30+
axes = axes | st.none()
31+
axis = data.draw(axes, label='axis')
32+
_axis, = sh.normalise_axis(axis, x.ndim)
33+
dtype = data.draw(kwarg_dtypes(x.dtype))
34+
include_initial = data.draw(st.booleans(), label="include_initial")
35+
36+
kw = data.draw(
37+
hh.specified_kwargs(
38+
("axis", axis, None),
39+
("dtype", dtype, None),
40+
("include_initial", include_initial, False),
41+
),
42+
label="kw",
43+
)
44+
45+
out = xp.cumulative_sum(x, **kw)
46+
47+
expected_shape = list(x.shape)
48+
if include_initial:
49+
expected_shape[_axis] += 1
50+
expected_shape = tuple(expected_shape)
51+
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape)
52+
53+
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
54+
if expected_dtype is None:
55+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
56+
# uint32 or uint64), we skip testing the output dtype.
57+
# See https://github.com/data-apis/array-api-tests/issues/106
58+
if x.dtype in dh.uint_dtypes:
59+
assert dh.is_int_dtype(out.dtype) # sanity check
60+
else:
61+
ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
62+
63+
scalar_type = dh.get_scalar_type(out.dtype)
64+
65+
for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis):
66+
x_arr = x[x_idx.raw]
67+
out_arr = out[out_idx.raw]
68+
69+
if include_initial:
70+
ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0)
71+
72+
for n in range(x.shape[_axis]):
73+
start = 1 if include_initial else 0
74+
out_val = out_arr[n + start]
75+
assume(cmath.isfinite(out_val))
76+
elements = []
77+
for idx in range(n + 1):
78+
s = scalar_type(x_arr[idx])
79+
elements.append(s)
80+
expected = sum(elements)
81+
if dh.is_int_dtype(out.dtype):
82+
m, M = dh.dtype_ranges[out.dtype]
83+
assume(m <= expected <= M)
84+
ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
85+
idx=out_idx.raw, out=out_val,
86+
expected=expected)
2787

2888

2989
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:

0 commit comments

Comments
 (0)