Skip to content

Commit 34faced

Browse files
committed
Fix torch.linalg.vector_norm for axis=()
1 parent 376038e commit 34faced

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

array_api_compat/torch/linalg.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,23 @@ def vector_norm(
7878
) -> array:
7979
# torch.linalg.vector_norm incorrectly treats axis=() the same as axis=None
8080
if axis == ():
81-
keepdims = True
81+
out = kwargs.get('out')
82+
if out is None:
83+
dtype = None
84+
if x.dtype == torch.complex64:
85+
dtype = torch.float32
86+
elif x.dtype == torch.complex128:
87+
dtype = torch.float64
88+
89+
out = torch.zeros_like(x, dtype=dtype)
90+
91+
# The norm of a single scalar works out to abs(x) in every case except
92+
# for ord=0, which is x != 0.
93+
if ord == 0:
94+
out[:] = (x != 0)
95+
else:
96+
out[:] = torch.abs(x)
97+
return out
8298
return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
8399

84100
__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',

0 commit comments

Comments
 (0)