 import scipy.linalg
 
 import aesara.tensor
-import aesara.tensor.basic as aet
-import aesara.tensor.math as tm
 from aesara.graph.basic import Apply
 from aesara.graph.op import Op
 from aesara.tensor import as_tensor_variable
+from aesara.tensor import basic as aet
+from aesara.tensor import math as atm
 from aesara.tensor.type import matrix, tensor, vector
 
 
 logger = logging.getLogger(__name__)
 
-MATRIX_STRUCTURES = (
-    "general",
-    "symmetric",
-    "lower_triangular",
-    "upper_triangular",
-    "hermitian",
-    "banded",
-    "diagonal",
-    "toeplitz",
-)
-
 
 class Cholesky(Op):
     """
@@ -95,7 +84,7 @@ def L_op(self, inputs, outputs, gradients):
         # Replace the cholesky decomposition with 1 if there are nans
         # or solve_upper_triangular will throw a ValueError.
         if self.on_error == "nan":
-            ok = ~tm.any(tm.isnan(chol_x))
+            ok = ~atm.any(atm.isnan(chol_x))
             chol_x = aet.switch(ok, chol_x, 1)
             dz = aet.switch(ok, dz, 1)
@@ -206,17 +195,24 @@ class Solve(Op):
     For on CPU and GPU.
     """
 
-    __props__ = ("A_structure", "lower", "overwrite_A", "overwrite_b")
+    __props__ = (
+        "assume_a",
+        "lower",
+        "check_finite",  # "transposed"
+    )
 
     def __init__(
-        self, A_structure="general", lower=False, overwrite_A=False, overwrite_b=False
+        self,
+        assume_a="gen",
+        lower=False,
+        check_finite=True,  # transposed=False
     ):
-        if A_structure not in MATRIX_STRUCTURES:
-            raise ValueError("Invalid matrix structure argument", A_structure)
-        self.A_structure = A_structure
+        if assume_a not in ("gen", "sym", "her", "pos"):
+            raise ValueError(f"{assume_a} is not a recognized matrix structure")
+        self.assume_a = assume_a
         self.lower = lower
-        self.overwrite_A = overwrite_A
-        self.overwrite_b = overwrite_b
+        self.check_finite = check_finite
+        # self.transposed = transposed
 
     def __repr__(self):
         return "Solve{%s}" % str(self._props())
@@ -237,12 +233,33 @@ def make_node(self, A, b):
 
     def perform(self, node, inputs, output_storage):
         A, b = inputs
-        if self.A_structure == "lower_triangular":
-            rval = scipy.linalg.solve_triangular(A, b, lower=True)
-        elif self.A_structure == "upper_triangular":
-            rval = scipy.linalg.solve_triangular(A, b, lower=False)
+
+        if self.assume_a != "gen":
+            # if self.transposed:
+            #     if self.assume_a == "her":
+            #         trans = "C"
+            #     else:
+            #         trans = "T"
+            # else:
+            #     trans = "N"
+
+            rval = scipy.linalg.solve_triangular(
+                A,
+                b,
+                lower=self.lower,
+                check_finite=self.check_finite,
+                # trans=trans
+            )
         else:
-            rval = scipy.linalg.solve(A, b)
+            rval = scipy.linalg.solve(
+                A,
+                b,
+                assume_a=self.assume_a,
+                lower=self.lower,
+                check_finite=self.check_finite,
+                # transposed=self.transposed,
+            )
+
         output_storage[0][0] = rval
 
     # computes shape of x where x = inv(A) * b
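The new `perform` simply forwards to SciPy: any structure other than `"gen"` is routed through `scipy.linalg.solve_triangular`, while the general case passes `assume_a` straight to `scipy.linalg.solve`. A rough NumPy-level sketch of the two code paths (illustration only; the array values are invented):

    import numpy as np
    import scipy.linalg

    A = np.array([[4.0, 0.0], [1.0, 3.0]])  # lower-triangular test matrix
    b = np.array([1.0, 2.0])

    # assume_a != "gen": triangular path
    x_tri = scipy.linalg.solve_triangular(A, b, lower=True, check_finite=True)

    # assume_a == "gen": general path
    x_gen = scipy.linalg.solve(A, b, assume_a="gen", lower=True, check_finite=True)

    assert np.allclose(x_tri, x_gen)  # same answer here because A really is triangular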
@@ -257,7 +274,7 @@ def infer_shape(self, fgraph, node, shapes):
 
     def L_op(self, inputs, outputs, output_gradients):
         r"""
-        Reverse-mode gradient updates for matrix solve operation c = A \\\ b.
+        Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
 
         Symbolic expression for updates taken from [#]_.
 
@@ -269,53 +286,84 @@ def L_op(self, inputs, outputs, output_gradients):
 
         """
         A, b = inputs
+
         c = outputs[0]
+        # C is a scalar representing the entire graph
+        # `output_gradients` is (dC/dc,)
+        # We need to return (dC/d[inv(A)], dC/db)
         c_bar = output_gradients[0]
-        trans_map = {
-            "lower_triangular": "upper_triangular",
-            "upper_triangular": "lower_triangular",
-        }
+
         trans_solve_op = Solve(
-            # update A_structure and lower to account for a transpose operation
-            A_structure=trans_map.get(self.A_structure, self.A_structure),
+            assume_a=self.assume_a,
+            check_finite=self.check_finite,
             lower=not self.lower,
         )
         b_bar = trans_solve_op(A.T, c_bar)
         # force outer product if vector second input
-        A_bar = -tm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
-        if self.A_structure == "lower_triangular":
-            A_bar = aet.tril(A_bar)
-        elif self.A_structure == "upper_triangular":
-            A_bar = aet.triu(A_bar)
+        A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
+
+        if self.assume_a != "gen":
+            if self.lower:
+                A_bar = aet.tril(A_bar)
+            else:
+                A_bar = aet.triu(A_bar)
+
         return [A_bar, b_bar]
 
 
 solve = Solve()
-"""
-Solves the equation ``a x = b`` for x, where ``a`` is a matrix and
-``b`` can be either a vector or a matrix.
-
-Parameters
-----------
-a : `(M, M) symbolix matrix`
-    A square matrix
-b : `(M,) or (M, N) symbolic vector or matrix`
-    Right hand side matrix in ``a x = b``
-
-
-Returns
--------
-x : `(M, ) or (M, N) symbolic vector or matrix`
-    x will have the same shape as b
-"""
-# lower and upper triangular solves
-solve_lower_triangular = Solve(A_structure="lower_triangular", lower=True)
-"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is lower triangular."""
-solve_upper_triangular = Solve(A_structure="upper_triangular", lower=False)
-"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is upper triangular."""
-# symmetric solves
-solve_symmetric = Solve(A_structure="symmetric")
-"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is symmetric."""
+
+
+def solve(a, b, assume_a="gen", lower=False, check_finite=True):
+    """
+    Solve the linear equation set ``a * x = b`` for the unknown ``x``,
+    where ``a`` is a square matrix.
+
+    If the data matrix is known to be of a particular type, supplying the
+    corresponding string to the ``assume_a`` key chooses the dedicated solver.
+    The available options are
+
+    ===================  ========
+     generic matrix       'gen'
+     symmetric            'sym'
+     hermitian            'her'
+     positive definite    'pos'
+    ===================  ========
+
+    If omitted, ``'gen'`` is the default structure.
+
+    The datatype of the arrays defines which solver is called regardless
+    of the values. In other words, even when the complex array entries have
+    precisely zero imaginary parts, the complex solver will be called based
+    on the data type of the array.
+
+    Parameters
+    ----------
+    a : (N, N) array_like
+        Square input data
+    b : (N, NRHS) array_like
+        Input data for the right hand side.
+    lower : bool, optional
+        If True, use only the data contained in the lower triangle of `a`;
+        the default is to use the upper triangle. (Ignored for ``'gen'``.)
+    check_finite : bool, optional
+        Whether to check that the input matrices contain only finite numbers.
+        Disabling may give a performance gain, but may result in problems
+        (crashes, non-termination) if the inputs do contain infinities or NaNs.
+    assume_a : str, optional
+        Valid entries are explained above.
+    """
+    return Solve(
+        lower=lower,
+        check_finite=check_finite,
+        assume_a=assume_a,
+    )(a, b)
+
+
+# TODO: These are deprecated; emit a warning
+solve_lower_triangular = Solve(assume_a="sym", lower=True)
+solve_upper_triangular = Solve(assume_a="sym", lower=False)
+solve_symmetric = Solve(assume_a="sym")
 
 # TODO: Optimizations to replace multiplication by matrix inverse
 # with solve() Op (still unwritten)
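The rewritten `L_op` keeps the same adjoint algebra as before: for `c = A^{-1} b` with upstream gradient `c_bar`, it returns `b_bar = A^{-T} c_bar` and `A_bar = -b_bar c^T` (restricted to one triangle when `assume_a != "gen"`). A small NumPy finite-difference check of those formulas (a verification sketch, not part of the change):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.normal(size=(3, 3)) + 3.0 * np.eye(3)   # well-conditioned test matrix
    b = rng.normal(size=3)
    c = np.linalg.solve(A, b)                        # forward solve, c = A^{-1} b
    c_bar = rng.normal(size=3)                       # arbitrary upstream gradient

    b_bar = np.linalg.solve(A.T, c_bar)              # mirrors trans_solve_op(A.T, c_bar)
    A_bar = -np.outer(b_bar, c)                      # mirrors -atm.outer(b_bar, c)

    # Central finite differences of C(A) = c_bar . solve(A, b)
    eps = 1e-6
    A_fd = np.zeros_like(A)
    for i in range(3):
        for j in range(3):
            dA = np.zeros_like(A)
            dA[i, j] = eps
            A_fd[i, j] = (
                c_bar @ np.linalg.solve(A + dA, b) - c_bar @ np.linalg.solve(A - dA, b)
            ) / (2 * eps)

    assert np.allclose(A_bar, A_fd, atol=1e-5)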
@@ -456,7 +504,7 @@ def kron(a, b):
             "kron: inputs dimensions must sum to 3 or more. "
             f"You passed {int(a.ndim)} and {int(b.ndim)}."
         )
-    o = tm.outer(a, b)
+    o = atm.outer(a, b)
     o = o.reshape(aet.concatenate((a.shape, b.shape)), a.ndim + b.ndim)
     shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim)))
     if shf.ndim == 3:
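For reference, this is how the new module-level `solve` wrapper is meant to be called once the change lands (a hedged usage sketch; the input values are invented, and note that in this diff any `assume_a` other than `"gen"` dispatches to `solve_triangular` inside `Solve.perform`):

    import numpy as np
    import aesara
    from aesara.tensor.type import matrix, vector
    from aesara.tensor.slinalg import solve

    A = matrix("A")
    b = vector("b")
    x = solve(A, b)                      # assume_a="gen" -> scipy.linalg.solve
    f = aesara.function([A, b], x)

    A_val = np.array([[3.0, 1.0], [1.0, 2.0]])
    b_val = np.array([1.0, 0.0])
    print(f(A_val, b_val))               # matches np.linalg.solve(A_val, b_val)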