Skip to content

Commit 4f63331

Browse files
authored
Merge pull request #113 from ev-br/count_nonzero
Add `count_nonzero` and `cumulative_prod` from 2024.12 revision draft
2 parents 4847400 + 61bf3c1 commit 4f63331

File tree

4 files changed

+58
-8
lines changed

4 files changed

+58
-8
lines changed

array_api_strict/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,9 @@
293293

294294
__all__ += ["concat", "expand_dims", "flip", "moveaxis", "permute_dims", "repeat", "reshape", "roll", "squeeze", "stack", "tile", "unstack"]
295295

296-
from ._searching_functions import argmax, argmin, nonzero, searchsorted, where
296+
from ._searching_functions import argmax, argmin, nonzero, count_nonzero, searchsorted, where
297297

298-
__all__ += ["argmax", "argmin", "nonzero", "searchsorted", "where"]
298+
__all__ += ["argmax", "argmin", "nonzero", "count_nonzero", "searchsorted", "where"]
299299

300300
from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
301301

@@ -305,9 +305,9 @@
305305

306306
__all__ += ["argsort", "sort"]
307307

308-
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
308+
from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var
309309

310-
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
310+
__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"]
311311

312312
from ._utility_functions import all, any, diff
313313

array_api_strict/_searching_functions.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Literal, Optional, Tuple
9+
from typing import Literal, Optional, Tuple, Union
1010

1111
import numpy as np
1212

@@ -45,6 +45,24 @@ def nonzero(x: Array, /) -> Tuple[Array, ...]:
4545
raise ValueError("nonzero is not allowed on 0-dimensional arrays")
4646
return tuple(Array._new(i, device=x.device) for i in np.nonzero(x._array))
4747

48+
49+
@requires_api_version('2024.12')
50+
def count_nonzero(
51+
x: Array,
52+
/,
53+
*,
54+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
55+
keepdims: bool = False,
56+
) -> Array:
57+
"""
58+
Array API compatible wrapper for :py:func:`np.count_nonzero <numpy.count_nonzero>`
59+
60+
See its docstring for more information.
61+
"""
62+
arr = np.count_nonzero(x._array, axis=axis, keepdims=keepdims)
63+
return Array._new(np.asarray(arr), device=x.device)
64+
65+
4866
@requires_api_version('2023.12')
4967
def searchsorted(
5068
x1: Array,

array_api_strict/_statistical_functions.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ._array_object import Array
1010
from ._dtypes import float32, complex64
1111
from ._flags import requires_api_version, get_array_api_strict_flags
12-
from ._creation_functions import zeros
12+
from ._creation_functions import zeros, ones
1313
from ._manipulation_functions import concat
1414

1515
from typing import TYPE_CHECKING
@@ -31,7 +31,6 @@ def cumulative_sum(
3131
) -> Array:
3232
if x.dtype not in _numeric_dtypes:
3333
raise TypeError("Only numeric dtypes are allowed in cumulative_sum")
34-
dt = x.dtype if dtype is None else dtype
3534
if dtype is not None:
3635
dtype = dtype._np_dtype
3736

@@ -44,9 +43,40 @@ def cumulative_sum(
4443
if include_initial:
4544
if axis < 0:
4645
axis += x.ndim
47-
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
46+
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
4847
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
4948

49+
50+
@requires_api_version('2024.12')
51+
def cumulative_prod(
52+
x: Array,
53+
/,
54+
*,
55+
axis: Optional[int] = None,
56+
dtype: Optional[Dtype] = None,
57+
include_initial: bool = False,
58+
) -> Array:
59+
if x.dtype not in _numeric_dtypes:
60+
raise TypeError("Only numeric dtypes are allowed in cumulative_prod")
61+
if x.ndim == 0:
62+
raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod")
63+
64+
if dtype is not None:
65+
dtype = dtype._np_dtype
66+
67+
if axis is None:
68+
if x.ndim > 1:
69+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
70+
axis = 0
71+
72+
# np.cumprod does not support include_initial
73+
if include_initial:
74+
if axis < 0:
75+
axis += x.ndim
76+
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
77+
return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device)
78+
79+
5080
def max(
5181
x: Array,
5282
/,

array_api_strict/tests/test_flags.py

+2
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,8 @@ def test_api_version_2023_12(func_name):
307307
'reciprocal': lambda: xp.reciprocal(xp.asarray([2.])),
308308
'take_along_axis': lambda: xp.take_along_axis(xp.zeros((2, 3)),
309309
xp.zeros((1, 4), dtype=xp.int64)),
310+
'count_nonzero': lambda: xp.count_nonzero(xp.arange(3)),
311+
'cumulative_prod': lambda: xp.cumulative_prod(xp.arange(1, 5)),
310312
}
311313

312314
@pytest.mark.parametrize('func_name', api_version_2024_12_examples.keys())

0 commit comments

Comments
 (0)