Skip to content

Commit 605eec2

Browse files
committed
ENH: add testing of cumulative_prod
1 parent d649c0c commit 605eec2

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

array_api_tests/test_statistical_functions.py

+58
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,64 @@ def test_cumulative_sum(x, data):
9191
idx=out_idx.raw, out=out_val,
9292
expected=expected)
9393

94+
95+
96+
@pytest.mark.min_version("2024.12")
97+
@pytest.mark.unvectorized
98+
@given(
99+
x=hh.arrays(
100+
dtype=hh.numeric_dtypes,
101+
shape=hh.shapes(min_dims=1)),
102+
data=st.data(),
103+
)
104+
def test_cumulative_prod(x, data):
105+
axes = st.integers(-x.ndim, x.ndim - 1)
106+
if x.ndim == 1:
107+
axes = axes | st.none()
108+
axis = data.draw(axes, label='axis')
109+
_axis, = sh.normalize_axis(axis, x.ndim)
110+
dtype = data.draw(kwarg_dtypes(x.dtype))
111+
include_initial = data.draw(st.booleans(), label="include_initial")
112+
113+
kw = data.draw(
114+
hh.specified_kwargs(
115+
("axis", axis, None),
116+
("dtype", dtype, None),
117+
("include_initial", include_initial, False),
118+
),
119+
label="kw",
120+
)
121+
122+
out = xp.cumulative_prod(x, **kw)
123+
124+
expected_shape = list(x.shape)
125+
if include_initial:
126+
expected_shape[_axis] += 1
127+
expected_shape = tuple(expected_shape)
128+
ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape)
129+
130+
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
131+
if expected_dtype is None:
132+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
133+
# uint32 or uint64), we skip testing the output dtype.
134+
# See https://github.com/data-apis/array-api-tests/issues/106
135+
if x.dtype in dh.uint_dtypes:
136+
assert dh.is_int_dtype(out.dtype) # sanity check
137+
else:
138+
ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
139+
140+
scalar_type = dh.get_scalar_type(out.dtype)
141+
142+
for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis):
143+
x_arr = x[x_idx.raw]
144+
out_arr = out[out_idx.raw]
145+
146+
if include_initial:
147+
ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0)
148+
149+
#TODO: add value testing of cumulative_prod
150+
151+
94152
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
95153
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
96154
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]

0 commit comments

Comments
 (0)