1
+ import functools
2
+ import math
1
3
from typing import Sequence
2
4
3
5
import torch
@@ -26,23 +28,29 @@ def _atleast_float_2(a, b):
26
28
return a , b
27
29
28
30
29
- def _prod (iterable ):
30
- p = 1.0
31
- for x in iterable :
32
- p *= x
33
- return p
31
def linalg_errors(func):
    """Decorator translating torch's linear-algebra failures into ``LinAlgError``.

    Wraps ``func`` so that any ``torch._C._LinAlgError`` raised inside it is
    re-raised as the NumPy-compatible ``LinAlgError`` (defined elsewhere in
    this module), preserving the original exception's args.

    Fix: chain with ``from e`` so the original torch traceback is kept as
    ``__cause__`` instead of being reported as "another exception occurred
    during handling".
    """

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        try:
            return func(*args, **kwds)
        except torch._C._LinAlgError as e:
            raise LinAlgError(*e.args) from e

    return wrapped
34
40
35
41
36
42
# ### Matrix and vector products ###
37
43
38
44
39
45
@normalizer
@linalg_errors
def matrix_power(a: ArrayLike, n):
    """Raise a square matrix to the (possibly negative) integer power ``n``.

    Thin wrapper over ``torch.linalg.matrix_power``.

    Fix: the original called ``_atleat_float_1`` (typo) — every sibling
    wrapper in this module uses ``_atleast_float_1``, so the misspelled name
    would raise ``NameError`` on the first call.
    """
    a = _atleast_float_1(a)
    return torch.linalg.matrix_power(a, n)
43
50
44
51
45
52
@normalizer
@linalg_errors
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
    """Chained matrix product of ``inputs`` via ``torch.linalg.multi_dot``.

    The ``out=`` keyword is accepted for NumPy signature compatibility but
    is not forwarded to torch.
    """
    chained = torch.linalg.multi_dot(inputs)
    return chained
48
56
@@ -51,41 +59,46 @@ def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
51
59
52
60
53
61
@normalizer
@linalg_errors
def solve(a: ArrayLike, b: ArrayLike):
    """Solve the linear system ``a @ x = b`` with ``torch.linalg.solve``."""
    # Promote both operands to a floating dtype before delegating to torch.
    return torch.linalg.solve(*_atleast_float_2(a, b))
57
66
58
67
59
68
@normalizer
@linalg_errors
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
    """Least-squares solution of ``a @ x = b`` via ``torch.linalg.lstsq``."""
    a, b = _atleast_float_2(a, b)
    # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
    # on CUDA, only `gels` is available though, so use it instead
    if a.is_cuda or b.is_cuda:
        driver = "gels"
    else:
        driver = "gelsd"
    return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
64
76
65
77
66
78
@normalizer
@linalg_errors
def inv(a: ArrayLike):
    """Multiplicative inverse of a square matrix (``torch.linalg.inv``).

    Singular inputs surface as ``LinAlgError`` via the ``linalg_errors``
    decorator.
    """
    return torch.linalg.inv(_atleast_float_1(a))
74
84
75
85
76
86
@normalizer
@linalg_errors
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
    """Moore-Penrose pseudo-inverse via ``torch.linalg.pinv``.

    NumPy's ``rcond`` maps onto torch's ``rtol`` keyword.
    """
    return torch.linalg.pinv(_atleast_float_1(a), rtol=rcond, hermitian=hermitian)
80
91
81
92
82
93
@normalizer
@linalg_errors
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
    """Solve the tensor equation ``a x = b`` for x.

    NumPy's ``axes`` argument corresponds to torch's ``dims`` keyword.
    """
    # Float promotion first, then delegate.
    return torch.linalg.tensorsolve(*_atleast_float_2(a, b), dims=axes)
86
98
87
99
88
100
@normalizer
@linalg_errors
def tensorinv(a: ArrayLike, ind=2):
    """'Inverse' of an N-dimensional array, per ``torch.linalg.tensorinv``."""
    return torch.linalg.tensorinv(_atleast_float_1(a), ind=ind)
@@ -95,50 +108,44 @@ def tensorinv(a: ArrayLike, ind=2):
95
108
96
109
97
110
@normalizer
@linalg_errors
def det(a: ArrayLike):
    """Determinant of a square matrix (``torch.linalg.det``)."""
    return torch.linalg.det(_atleast_float_1(a))
101
115
102
116
103
117
@normalizer
@linalg_errors
def slogdet(a: ArrayLike):
    """Sign and natural log of the absolute determinant (``torch.linalg.slogdet``)."""
    return torch.linalg.slogdet(_atleast_float_1(a))
107
122
108
123
109
124
@normalizer
@linalg_errors
def cond(x: ArrayLike, p=None):
    """Condition number of a matrix in the norm ``p`` (``torch.linalg.cond``)."""
    x = _atleast_float_1(x)

    # Reject empty inputs whose trailing matrix dimensions are empty,
    # mirroring NumPy:
    # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
        raise LinAlgError("cond is not defined on empty arrays")

    cond_values = torch.linalg.cond(x, p=p)

    # Convert nans to infs unconditionally. NumPy does this in a
    # data-dependent way (only where the input itself was nan-free):
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
    return torch.where(torch.isnan(cond_values), math.inf, cond_values)
134
140
135
141
136
142
@normalizer
143
+ @linalg_errors
137
144
def matrix_rank (a : ArrayLike , tol = None , hermitian = False ):
138
145
a = _atleast_float_1 (a )
139
146
140
147
if a .ndim < 2 :
141
- return int (not (a == 0 ).all ())
148
+ return int ((a != 0 ).any ())
142
149
143
150
if tol is None :
144
151
# follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
@@ -150,6 +157,7 @@ def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
150
157
151
158
152
159
@normalizer
160
+ @linalg_errors
153
161
def norm (x : ArrayLike , ord = None , axis = None , keepdims = False ):
154
162
x = _atleast_float_1 (x )
155
163
result = torch .linalg .norm (x , ord = ord , dim = axis )
@@ -162,12 +170,14 @@ def norm(x: ArrayLike, ord=None, axis=None, keepdims=False):
162
170
163
171
164
172
@normalizer
@linalg_errors
def cholesky(a: ArrayLike):
    """Cholesky factor of a Hermitian positive-definite matrix.

    Non-positive-definite inputs surface as ``LinAlgError`` via the
    ``linalg_errors`` decorator.
    """
    return torch.linalg.cholesky(_atleast_float_1(a))
168
177
169
178
170
179
@normalizer
180
+ @linalg_errors
171
181
def qr (a : ArrayLike , mode = "reduced" ):
172
182
a = _atleast_float_1 (a )
173
183
result = torch .linalg .qr (a , mode = mode )
@@ -178,19 +188,22 @@ def qr(a: ArrayLike, mode="reduced"):
178
188
179
189
180
190
@normalizer
@linalg_errors
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
    """Singular value decomposition, NumPy-style.

    With ``compute_uv=True`` returns the full (U, S, Vh) factorization;
    otherwise only the singular values.
    """
    a = _atleast_float_1(a)
    if compute_uv:
        # NB: ignore the hermitian= argument (no pytorch equivalent)
        return torch.linalg.svd(a, full_matrices=full_matrices)
    # Singular values only — torch has a dedicated entry point for this.
    return torch.linalg.svdvals(a)
188
200
189
201
190
202
# ### Eigenvalues and eigenvectors ###
191
203
192
204
193
205
@normalizer
206
+ @linalg_errors
194
207
def eig (a : ArrayLike ):
195
208
a = _atleast_float_1 (a )
196
209
w , vt = torch .linalg .eig (a )
@@ -203,12 +216,14 @@ def eig(a: ArrayLike):
203
216
204
217
205
218
@normalizer
@linalg_errors
def eigh(a: ArrayLike, UPLO="L"):
    """Eigenvalues and eigenvectors of a Hermitian/symmetric matrix.

    ``UPLO`` selects which triangle of ``a`` torch reads ("L" or "U").
    """
    return torch.linalg.eigh(_atleast_float_1(a), UPLO=UPLO)
209
223
210
224
211
225
@normalizer
226
+ @linalg_errors
212
227
def eigvals (a : ArrayLike ):
213
228
a = _atleast_float_1 (a )
214
229
result = torch .linalg .eigvals (a )
@@ -219,6 +234,7 @@ def eigvals(a: ArrayLike):
219
234
220
235
221
236
@normalizer
@linalg_errors
def eigvalsh(a: ArrayLike, UPLO="L"):
    """Eigenvalues only of a Hermitian/symmetric matrix.

    ``UPLO`` selects which triangle of ``a`` torch reads ("L" or "U").
    """
    return torch.linalg.eigvalsh(_atleast_float_1(a), UPLO=UPLO)
0 commit comments