Skip to content

Commit fece5e0

Browse files
committed
Fix some issues with the linalg wrapping
1 parent 360ea18 commit fece5e0

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

array_api_compat/common/_linalg.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def svd(x: ndarray, /, xp, *, full_matrices: bool = True) -> SVDResult:
5959
def cholesky(x: ndarray, /, xp, *, upper: bool = False) -> ndarray:
6060
L = xp.linalg.cholesky(x)
6161
if upper:
62-
return matrix_transpose(L)
62+
return get_xp(xp)(matrix_transpose)(L)
6363
return L
6464

6565
# The rtol keyword argument of matrix_rank() and pinv() is new from NumPy.
@@ -158,7 +158,8 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0) -> ndarray:
158158
def trace(x: ndarray, /, xp, *, offset: int = 0) -> ndarray:
159159
return xp.asarray(xp.trace(x, offset=offset, axis1=-2, axis2=-1))
160160

161-
__all__ = ['cross', 'diagonal', 'matmul', 'cholesky', 'matrix_rank', 'pinv',
162-
'matrix_norm', 'matrix_transpose', 'outer', 'svdvals',
163-
'tensordot', 'trace', 'vecdot', 'vector_norm', 'EighResult',
164-
'QRResult', 'SlogdetResult', 'SVDResult']
161+
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
162+
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
163+
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
164+
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
165+
'trace']

array_api_compat/cupy/linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,27 @@
1414
import cupy as cp
1515

1616
cross = get_xp(cp)(_linalg.cross)
17-
diagonal = get_xp(cp)(_linalg.diagonal)
1817
matmul = get_xp(cp)(_linalg.matmul)
18+
outer = get_xp(cp)(_linalg.outer)
19+
tensordot = get_xp(cp)(_linalg.tensordot)
20+
EighResult = _linalg.EighResult
21+
QRResult = _linalg.QRResult
22+
SlogdetResult = _linalg.SlogdetResult
23+
SVDResult = _linalg.SVDResult
24+
eigh = get_xp(cp)(_linalg.eigh)
25+
qr = get_xp(cp)(_linalg.qr)
26+
slogdet = get_xp(cp)(_linalg.slogdet)
27+
svd = get_xp(cp)(_linalg.svd)
1928
cholesky = get_xp(cp)(_linalg.cholesky)
2029
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
2130
pinv = get_xp(cp)(_linalg.pinv)
2231
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
2332
matrix_transpose = get_xp(cp)(_linalg.matrix_transpose)
24-
outer = get_xp(cp)(_linalg.outer)
2533
svdvals = get_xp(cp)(_linalg.svdvals)
26-
tensordot = get_xp(cp)(_linalg.tensordot)
27-
trace = get_xp(cp)(_linalg.trace)
2834
vecdot = get_xp(cp)(_linalg.vecdot)
2935
vector_norm = get_xp(cp)(_linalg.vector_norm)
36+
diagonal = get_xp(cp)(_linalg.diagonal)
37+
trace = get_xp(cp)(_linalg.trace)
3038

3139
__all__ = linalg_all + _linalg.__all__
3240

array_api_compat/numpy/linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,27 @@
77
import numpy as np
88

99
cross = get_xp(np)(_linalg.cross)
10-
diagonal = get_xp(np)(_linalg.diagonal)
1110
matmul = get_xp(np)(_linalg.matmul)
11+
outer = get_xp(np)(_linalg.outer)
12+
tensordot = get_xp(np)(_linalg.tensordot)
13+
EighResult = _linalg.EighResult
14+
QRResult = _linalg.QRResult
15+
SlogdetResult = _linalg.SlogdetResult
16+
SVDResult = _linalg.SVDResult
17+
eigh = get_xp(np)(_linalg.eigh)
18+
qr = get_xp(np)(_linalg.qr)
19+
slogdet = get_xp(np)(_linalg.slogdet)
20+
svd = get_xp(np)(_linalg.svd)
1221
cholesky = get_xp(np)(_linalg.cholesky)
1322
matrix_rank = get_xp(np)(_linalg.matrix_rank)
1423
pinv = get_xp(np)(_linalg.pinv)
1524
matrix_norm = get_xp(np)(_linalg.matrix_norm)
1625
matrix_transpose = get_xp(np)(_linalg.matrix_transpose)
17-
outer = get_xp(np)(_linalg.outer)
1826
svdvals = get_xp(np)(_linalg.svdvals)
19-
tensordot = get_xp(np)(_linalg.tensordot)
20-
trace = get_xp(np)(_linalg.trace)
2127
vecdot = get_xp(np)(_linalg.vecdot)
2228
vector_norm = get_xp(np)(_linalg.vector_norm)
29+
diagonal = get_xp(np)(_linalg.diagonal)
30+
trace = get_xp(np)(_linalg.trace)
2331

2432
__all__ = linalg_all + _linalg.__all__
2533

0 commit comments

Comments
 (0)