|
5 | 5 | import pytest
|
6 | 6 | from hypothesis import assume, given
|
7 | 7 | from hypothesis import strategies as st
|
| 8 | +from ndindex import iter_indices |
8 | 9 |
|
9 | 10 | from . import _array_module as xp
|
10 | 11 | from . import dtype_helpers as dh
|
|
17 | 18 |
|
18 | 19 |
|
19 | 20 | @pytest.mark.min_version("2023.12")
|
20 |
| -@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1))) |
21 |
| -def test_cumulative_sum(x): |
22 |
| - # TODO: test kwargs + diff shapes, adjust shape and values testing accordingly |
23 |
| - out = xp.cumulative_sum(x) |
24 |
| - # TODO: assert dtype |
25 |
| - ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape) |
26 |
| - # TODO: assert values |
| 21 | +@given( |
| 22 | + x=hh.arrays( |
| 23 | + dtype=xps.numeric_dtypes(), |
| 24 | + shape=hh.shapes(min_dims=1)), |
| 25 | + data=st.data(), |
| 26 | +) |
| 27 | +def test_cumulative_sum(x, data): |
| 28 | + axes = st.integers(-x.ndim, x.ndim - 1) |
| 29 | + if x.ndim == 1: |
| 30 | + axes = axes | st.none() |
| 31 | + axis = data.draw(axes, label='axis') |
| 32 | + _axis, = sh.normalise_axis(axis, x.ndim) |
| 33 | + dtype = data.draw(kwarg_dtypes(x.dtype)) |
| 34 | + include_initial = data.draw(st.booleans(), label="include_initial") |
| 35 | + |
| 36 | + kw = data.draw( |
| 37 | + hh.specified_kwargs( |
| 38 | + ("axis", axis, None), |
| 39 | + ("dtype", dtype, None), |
| 40 | + ("include_initial", include_initial, False), |
| 41 | + ), |
| 42 | + label="kw", |
| 43 | + ) |
| 44 | + |
| 45 | + out = xp.cumulative_sum(x, **kw) |
| 46 | + |
| 47 | + expected_shape = list(x.shape) |
| 48 | + if include_initial: |
| 49 | + expected_shape[_axis] += 1 |
| 50 | + expected_shape = tuple(expected_shape) |
| 51 | + ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape) |
| 52 | + |
| 53 | + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) |
| 54 | + if expected_dtype is None: |
| 55 | + # If a default uint cannot exist (i.e. in PyTorch which doesn't support |
| 56 | + # uint32 or uint64), we skip testing the output dtype. |
| 57 | + # See https://github.com/data-apis/array-api-tests/issues/106 |
| 58 | + if x.dtype in dh.uint_dtypes: |
| 59 | + assert dh.is_int_dtype(out.dtype) # sanity check |
| 60 | + else: |
| 61 | + ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) |
| 62 | + |
| 63 | + scalar_type = dh.get_scalar_type(out.dtype) |
| 64 | + |
| 65 | + for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): |
| 66 | + x_arr = x[x_idx.raw] |
| 67 | + out_arr = out[out_idx.raw] |
| 68 | + |
| 69 | + if include_initial: |
| 70 | + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0) |
| 71 | + |
| 72 | + for n in range(x.shape[_axis]): |
| 73 | + start = 1 if include_initial else 0 |
| 74 | + out_val = out_arr[n + start] |
| 75 | + assume(cmath.isfinite(out_val)) |
| 76 | + elements = [] |
| 77 | + for idx in range(n + 1): |
| 78 | + s = scalar_type(x_arr[idx]) |
| 79 | + elements.append(s) |
| 80 | + expected = sum(elements) |
| 81 | + if dh.is_int_dtype(out.dtype): |
| 82 | + m, M = dh.dtype_ranges[out.dtype] |
| 83 | + assume(m <= expected <= M) |
| 84 | + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, |
| 85 | + idx=out_idx.raw, out=out_val, |
| 86 | + expected=expected) |
27 | 87 |
|
28 | 88 |
|
29 | 89 | def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
|
|
0 commit comments