Skip to content

Commit f559469

Browse files
lucascolleyrgommers
authored andcommitted
BUG: fix cholesky upper decomp for complex dtypes
1 parent 874c2ff commit f559469

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

array_api_compat/common/_linalg.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
else:
1212
from numpy.core.numeric import normalize_axis_tuple
1313

14-
from ._aliases import matmul, matrix_transpose, tensordot, vecdot
14+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1515
from .._internal import get_xp
1616

1717
# These are in the main NumPy namespace but not in numpy.linalg
@@ -59,7 +59,10 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult
5959
def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray:
6060
L = xp.linalg.cholesky(x, **kwargs)
6161
if upper:
62-
return get_xp(xp)(matrix_transpose)(L)
62+
U = get_xp(xp)(matrix_transpose)(L)
63+
if get_xp(xp)(isdtype)(U.dtype, 'complex floating'):
64+
U = xp.conj(U)
65+
return U
6366
return L
6467

6568
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.

0 commit comments

Comments
 (0)