Skip to content

Commit 5ffffee

Browse files
authored
feat: add count_nonzero to specification
PR-URL: #803 Closes: #794
1 parent 5cdcf75 commit 5ffffee

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

spec/draft/API_specification/searching_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Objects in API
2222

2323
argmax
2424
argmin
25+
count_nonzero
2526
nonzero
2627
searchsorted
2728
where

src/array_api_stubs/_draft/searching_functions.py

+38-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
__all__ = ["argmax", "argmin", "nonzero", "searchsorted", "where"]
1+
__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"]
22

33

4-
from ._types import Optional, Tuple, Literal, array
4+
from ._types import Optional, Tuple, Literal, Union, array
55

66

77
def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
@@ -54,15 +54,41 @@ def argmin(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -
5454
"""
5555

5656

57-
def nonzero(x: array, /) -> Tuple[array, ...]:
57+
def count_nonzero(
58+
x: array,
59+
/,
60+
*,
61+
axis: Optional[Union[int, Tuple[int, ...]]] = None,
62+
keepdims: bool = False,
63+
) -> array:
5864
"""
59-
Returns the indices of the array elements which are non-zero.
65+
Counts the number of array elements which are non-zero.
6066
61-
.. note::
62-
If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero.
67+
Parameters
68+
----------
69+
x: array
70+
input array.
71+
axis: Optional[Union[int, Tuple[int, ...]]]
72+
axis or axes along which to count non-zero values. By default, the number of non-zero values must be computed over the entire array. If a tuple of integers, the number of non-zero values must be computed over multiple axes. Default: ``None``.
73+
keepdims: bool
74+
if ``True``, the reduced axes (dimensions) must be included in the result as singleton dimensions, and, accordingly, the result must be compatible with the input array (see :ref:`broadcasting`). Otherwise, if ``False``, the reduced axes (dimensions) must not be included in the result. Default: ``False``.
6375
64-
.. note::
65-
If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``.
76+
Returns
77+
-------
78+
out: array
79+
if the number of non-zeros values was computed over the entire array, a zero-dimensional array containing the total number of non-zero values; otherwise, a non-zero-dimensional array containing the counts along the specified axes. The returned array must have the default array index data type.
80+
81+
Notes
82+
-----
83+
84+
- If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero.
85+
- If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``.
86+
"""
87+
88+
89+
def nonzero(x: array, /) -> Tuple[array, ...]:
90+
"""
91+
Returns the indices of the array elements which are non-zero.
6692
6793
.. admonition:: Data-dependent output shape
6894
:class: admonition important
@@ -76,12 +102,15 @@ def nonzero(x: array, /) -> Tuple[array, ...]:
76102
77103
Returns
78104
-------
79-
out: Typle[array, ...]
105+
out: Tuple[array, ...]
80106
a tuple of ``k`` arrays, one for each dimension of ``x`` and each of size ``n`` (where ``n`` is the total number of non-zero elements), containing the indices of the non-zero elements in that dimension. The indices must be returned in row-major, C-style order. The returned array must have the default array index data type.
81107
82108
Notes
83109
-----
84110
111+
- If ``x`` has a complex floating-point data type, non-zero elements are those elements having at least one component (real or imaginary) which is non-zero.
112+
- If ``x`` has a boolean data type, non-zero elements are those elements which are equal to ``True``.
113+
85114
.. versionchanged:: 2022.12
86115
Added complex data type support.
87116
"""

0 commit comments

Comments
 (0)