Skip to content

Commit 739730c

Browse files
committed
Wrap numpy and cupy nonzero to error on zero-dimensional arrays
1 parent 486ca51 commit 739730c

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

array_api_compat/common/_aliases.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,12 @@ def sort(
386386
res = xp.flip(res, axis=axis)
387387
return res
388388

389+
# nonzero should error for zero-dimensional arrays
390+
def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
391+
if x.ndim == 0:
392+
raise ValueError("nonzero() does not support zero-dimensional arrays")
393+
return xp.nonzero(x, **kwargs)
394+
389395
# sum() and prod() should always upcast when dtype=None
390396
def sum(
391397
x: ndarray,
@@ -526,5 +532,5 @@ def isdtype(
526532
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
527533
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
528534
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
529-
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
530-
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
535+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
536+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
reshape = get_xp(cp)(_aliases.reshape)
5353
argsort = get_xp(cp)(_aliases.argsort)
5454
sort = get_xp(cp)(_aliases.sort)
55+
nonzero = get_xp(cp)(_aliases.nonzero)
5556
sum = get_xp(cp)(_aliases.sum)
5657
prod = get_xp(cp)(_aliases.prod)
5758
ceil = get_xp(cp)(_aliases.ceil)

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
reshape = get_xp(np)(_aliases.reshape)
5353
argsort = get_xp(np)(_aliases.argsort)
5454
sort = get_xp(np)(_aliases.sort)
55+
nonzero = get_xp(np)(_aliases.nonzero)
5556
sum = get_xp(np)(_aliases.sum)
5657
prod = get_xp(np)(_aliases.prod)
5758
ceil = get_xp(np)(_aliases.ceil)

0 commit comments

Comments
 (0)