Skip to content

Commit f50cb75

Browse files
committed
MAINT: address review comments; convert linalg exceptions
1 parent 015bc44 commit f50cb75

File tree

2 files changed

+43
-29
lines changed

2 files changed

+43
-29
lines changed

torch_np/linalg.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import functools
2+
import math
13
from typing import Sequence
24

35
import torch
@@ -26,23 +28,29 @@ def _atleast_float_2(a, b):
2628
return a, b
2729

2830

def linalg_errors(func):
    """Decorator converting torch linear-algebra failures to ``LinAlgError``.

    Wraps *func* so that any ``torch._C._LinAlgError`` raised inside it is
    re-raised as the numpy-compatible ``LinAlgError``, preserving the
    original message.

    Fix: chain the torch exception explicitly (``raise ... from e``) so the
    cause is recorded as a direct cause rather than an incidental
    "during handling of" context (PEP 3134).
    """

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        try:
            return func(*args, **kwds)
        except torch._C._LinAlgError as e:
            raise LinAlgError(*e.args) from e

    return wrapped
3440

3541

3642
# ### Matrix and vector products ###
3743

3844

3945
@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Raise a square matrix to the (integer) power *n*.

    Fix: the helper was misspelled ``_atleat_float_1``; every other function
    in this module calls ``_atleast_float_1`` (cf. ``inv``, ``pinv``), so the
    original would raise ``NameError`` at call time.
    """
    a = _atleast_float_1(a)
    return torch.linalg.matrix_power(a, n)
4350

4451

4552
@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Chained matrix product of a sequence of arrays.

    NOTE(review): the ``out=`` argument is accepted for numpy signature
    compatibility but not forwarded to torch — confirm this is intentional.
    """
    product = torch.linalg.multi_dot(inputs)
    return product
4856

@@ -51,41 +59,46 @@ def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
5159

5260

5361
@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system ``a @ x == b`` for ``x``."""
    coeff, rhs = _atleast_float_2(a, b)
    return torch.linalg.solve(coeff, rhs)
5766

5867

5968
@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Least-squares solution to ``a @ x = b``.

    NumPy uses the LAPACK ``gelsd`` driver:
    https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    On CUDA only ``gels`` is available, so fall back to it there.
    """
    a, b = _atleast_float_2(a, b)
    on_gpu = a.is_cuda or b.is_cuda
    driver = "gels" if on_gpu else "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
6476

6577

6678
@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Multiplicative inverse of a (batch of) square matrix.

    Error conversion is handled by the ``linalg_errors`` decorator.
    """
    a = _atleast_float_1(a)
    return torch.linalg.inv(a)
7484

7585

7686
@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Moore-Penrose pseudo-inverse; numpy's ``rcond`` maps to torch's ``rtol``."""
    a = _atleast_float_1(a)
    result = torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian)
    return result
8091

8192

8293
@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Solve the tensor equation ``a x = b`` for x (numpy ``axes`` -> torch ``dims``)."""
    lhs, rhs = _atleast_float_2(a, b)
    return torch.linalg.tensorsolve(lhs, rhs, dims=axes)
8698

8799

88100
@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """'Inverse' of an N-dimensional array w.r.t. the tensordot operation."""
    a = _atleast_float_1(a)
    result = torch.linalg.tensorinv(a, ind=ind)
    return result
@@ -95,50 +108,44 @@ def tensorinv(a: ArrayLike, ind=2):
95108

96109

97110
@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Determinant of a (batch of) square matrix."""
    mat = _atleast_float_1(a)
    return torch.linalg.det(mat)
101115

102116

103117
@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Sign and natural log of the absolute value of the determinant."""
    mat = _atleast_float_1(a)
    return torch.linalg.slogdet(mat)
107122

108123

109124
@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Condition number of a matrix, in the norm selected by ``p``."""
    x = _atleast_float_1(x)

    # NumPy rejects empty matrices, cf.
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    result = torch.linalg.cond(x, p=p)

    # Replace nans by infs. NumPy does this data-dependently (only when the
    # input itself is nan-free), cf.
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    nan_mask = torch.isnan(result)
    return torch.where(nan_mask, float("inf"), result)
134140

135141

136142
@normalizer
143+
@linalg_errors
137144
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
138145
a = _atleast_float_1(a)
139146

140147
if a.ndim < 2:
141-
return int(not (a == 0).all())
148+
return int((a != 0).any())
142149

143150
if tol is None:
144151
# follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
@@ -150,6 +157,7 @@ def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
150157

151158

152159
@normalizer
160+
@linalg_errors
153161
def norm(x: ArrayLike, ord=None, axis=None, keepdims=False):
154162
x = _atleast_float_1(x)
155163
result = torch.linalg.norm(x, ord=ord, dim=axis)
@@ -162,12 +170,14 @@ def norm(x: ArrayLike, ord=None, axis=None, keepdims=False):
162170

163171

164172
@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Cholesky decomposition of a (batch of) positive-definite matrix."""
    mat = _atleast_float_1(a)
    return torch.linalg.cholesky(mat)
168177

169178

170179
@normalizer
180+
@linalg_errors
171181
def qr(a: ArrayLike, mode="reduced"):
172182
a = _atleast_float_1(a)
173183
result = torch.linalg.qr(a, mode=mode)
@@ -178,19 +188,22 @@ def qr(a: ArrayLike, mode="reduced"):
178188

179189

180190
@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Singular value decomposition.

    With ``compute_uv=False`` only the singular values are computed, via
    ``torch.linalg.svdvals``.
    """
    a = _atleast_float_1(a)

    if not compute_uv:
        return torch.linalg.svdvals(a)

    # NB: the hermitian= argument is ignored (no pytorch equivalent)
    return torch.linalg.svd(a, full_matrices=full_matrices)
188200

189201

190202
# ### Eigenvalues and eigenvectors ###
191203

192204

193205
@normalizer
206+
@linalg_errors
194207
def eig(a: ArrayLike):
195208
a = _atleast_float_1(a)
196209
w, vt = torch.linalg.eig(a)
@@ -203,12 +216,14 @@ def eig(a: ArrayLike):
203216

204217

205218
@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Eigenvalues and eigenvectors of a (batch of) hermitian/symmetric matrix."""
    mat = _atleast_float_1(a)
    return torch.linalg.eigh(mat, UPLO=UPLO)
209223

210224

211225
@normalizer
226+
@linalg_errors
212227
def eigvals(a: ArrayLike):
213228
a = _atleast_float_1(a)
214229
result = torch.linalg.eigvals(a)
@@ -219,6 +234,7 @@ def eigvals(a: ArrayLike):
219234

220235

221236
@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Eigenvalues of a (batch of) hermitian/symmetric matrix."""
    mat = _atleast_float_1(a)
    return torch.linalg.eigvalsh(mat, UPLO=UPLO)

torch_np/tests/numpy_tests/linalg/test_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,6 @@ def test_basic_nonsvd(self):
705705
assert_almost_equal(linalg.cond(A, -1), 0.5)
706706
assert_almost_equal(linalg.cond(A, 'fro'), np.sqrt(265 / 12))
707707

708-
@pytest.mark.xfail(reason="nan -> inf, as numpy does")
709708
def test_singular(self):
710709
# Singular matrices have infinite condition number for
711710
# positive norms, and negative norms shouldn't raise
@@ -747,7 +746,6 @@ def test_nan(self):
747746
assert_(not np.isnan(c[0]))
748747
assert_(not np.isnan(c[2]))
749748

750-
@pytest.mark.xfail(reason="nan -> inf, as numpy does")
751749
def test_stacked_singular(self):
752750
# Check behavior when only some of the stacked matrices are
753751
# singular

0 commit comments

Comments
 (0)