Skip to content

Commit 0fed389

Browse files
committed
BUG: torch: fix count_nonzero with axis tuple and keepdims
1 parent 52e01be commit 0fed389

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,12 @@ def count_nonzero(
548548
) -> Array:
549549
result = torch.count_nonzero(x, dim=axis)
550550
if keepdims:
551-
if axis is not None:
551+
if isinstance(axis, int):
552552
return result.unsqueeze(axis)
553+
elif isinstance(axis, tuple):
554+
n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
555+
sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
556+
return torch.reshape(result, sh)
553557
return _axis_none_keepdims(result, x.ndim, keepdims)
554558
else:
555559
return result

0 commit comments

Comments
 (0)