1
1
import logging
2
2
import warnings
3
+ from typing import Union
3
4
4
5
import numpy as np
5
6
import scipy .linalg
11
12
from aesara .tensor import basic as aet
12
13
from aesara .tensor import math as atm
13
14
from aesara .tensor .type import matrix , tensor , vector
15
+ from aesara .tensor .var import TensorVariable
14
16
15
17
16
18
logger = logging .getLogger (__name__ )
@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
259
261
return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
260
262
261
263
262
- class Solve (Op ):
263
- """
264
- Solve a system of linear equations.
265
-
266
- For on CPU and GPU.
267
- """
264
+ class SolveBase (Op ):
265
+ """Base class for `scipy.linalg` matrix equation solvers."""
268
266
269
267
__props__ = (
270
- "assume_a" ,
271
268
"lower" ,
272
- "check_finite" , # "transposed"
269
+ "check_finite" ,
273
270
)
274
271
275
272
def __init__ (
276
273
self ,
277
- assume_a = "gen" ,
278
274
lower = False ,
279
- check_finite = True , # transposed=False
275
+ check_finite = True ,
280
276
):
281
- if assume_a not in ("gen" , "sym" , "her" , "pos" ):
282
- raise ValueError (f"{ assume_a } is not a recognized matrix structure" )
283
- self .assume_a = assume_a
284
277
self .lower = lower
285
278
self .check_finite = check_finite
286
- # self.transposed = transposed
287
279
288
- def __repr__ (self ):
289
- return "Solve{%s}" % str ( self . _props ())
280
+ def perform (self , node , inputs , outputs ):
281
+ pass
290
282
291
283
def make_node (self , A , b ):
292
284
A = as_tensor_variable (A )
293
285
b = as_tensor_variable (b )
294
- assert A .ndim == 2
295
- assert b .ndim in [1 , 2 ]
296
286
297
- # infer dtype by solving the most simple
298
- # case with (1, 1) matrices
287
+ if A .ndim != 2 :
288
+ raise ValueError (f"`A` must be a matrix; got { A .type } instead." )
289
+ if b .ndim not in [1 , 2 ]:
290
+ raise ValueError (f"`b` must be a matrix or a vector; got { b .type } instead." )
291
+
292
+ # Infer dtype by solving the most simple case with 1x1 matrices
299
293
o_dtype = scipy .linalg .solve (
300
294
np .eye (1 ).astype (A .dtype ), np .eye (1 ).astype (b .dtype )
301
295
).dtype
302
296
x = tensor (broadcastable = b .broadcastable , dtype = o_dtype )
303
297
return Apply (self , [A , b ], [x ])
304
298
305
- def perform (self , node , inputs , output_storage ):
306
- A , b = inputs
307
-
308
- if self .assume_a != "gen" :
309
- # if self.transposed:
310
- # if self.assume_a == "her":
311
- # trans = "C"
312
- # else:
313
- # trans = "T"
314
- # else:
315
- # trans = "N"
316
-
317
- rval = scipy .linalg .solve_triangular (
318
- A ,
319
- b ,
320
- lower = self .lower ,
321
- check_finite = self .check_finite ,
322
- # trans=trans
323
- )
324
- else :
325
- rval = scipy .linalg .solve (
326
- A ,
327
- b ,
328
- assume_a = self .assume_a ,
329
- lower = self .lower ,
330
- check_finite = self .check_finite ,
331
- # transposed=self.transposed,
332
- )
333
-
334
- output_storage [0 ][0 ] = rval
335
-
336
- # computes shape of x where x = inv(A) * b
337
299
def infer_shape (self , fgraph , node , shapes ):
338
300
Ashape , Bshape = shapes
339
301
rows = Ashape [1 ]
340
- if len (Bshape ) == 1 : # b is a Vector
302
+ if len (Bshape ) == 1 :
341
303
return [(rows ,)]
342
304
else :
343
- cols = Bshape [1 ] # b is a Matrix
305
+ cols = Bshape [1 ]
344
306
return [(rows , cols )]
345
307
346
308
def L_op (self , inputs , outputs , output_gradients ):
347
- r"""
348
- Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
309
+ r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
349
310
350
311
Symbolic expression for updates taken from [#]_.
351
312
@@ -364,31 +325,148 @@ def L_op(self, inputs, outputs, output_gradients):
364
325
# We need to return (dC/d[inv(A)], dC/db)
365
326
c_bar = output_gradients [0 ]
366
327
367
- trans_solve_op = Solve (
368
- assume_a = self .assume_a ,
369
- check_finite = self .check_finite ,
370
- lower = not self .lower ,
328
+ trans_solve_op = type (self )(
329
+ ** {
330
+ k : (not getattr (self , k ) if k == "lower" else getattr (self , k ))
331
+ for k in self .__props__
332
+ }
371
333
)
372
334
b_bar = trans_solve_op (A .T , c_bar )
373
335
# force outer product if vector second input
374
336
A_bar = - atm .outer (b_bar , c ) if c .ndim == 1 else - b_bar .dot (c .T )
375
337
376
- if self .assume_a != "gen" :
377
- if self .lower :
378
- A_bar = aet .tril (A_bar )
379
- else :
380
- A_bar = aet .triu (A_bar )
381
-
382
338
return [A_bar , b_bar ]
383
339
340
+ def __repr__ (self ):
341
+ return f"{ type (self ).__name__ } { self ._props ()} "
342
+
343
+
344
+ class SolveTriangular (SolveBase ):
345
+ """Solve a system of linear equations."""
346
+
347
+ __props__ = (
348
+ "lower" ,
349
+ "trans" ,
350
+ "unit_diagonal" ,
351
+ "check_finite" ,
352
+ )
353
+
354
+ def __init__ (
355
+ self ,
356
+ trans = 0 ,
357
+ lower = False ,
358
+ unit_diagonal = False ,
359
+ check_finite = True ,
360
+ ):
361
+ super ().__init__ (lower = lower , check_finite = check_finite )
362
+ self .trans = trans
363
+ self .unit_diagonal = unit_diagonal
364
+
365
+ def perform (self , node , inputs , outputs ):
366
+ A , b = inputs
367
+ outputs [0 ][0 ] = scipy .linalg .solve_triangular (
368
+ A ,
369
+ b ,
370
+ lower = self .lower ,
371
+ trans = self .trans ,
372
+ unit_diagonal = self .unit_diagonal ,
373
+ check_finite = self .check_finite ,
374
+ )
375
+
376
+ def L_op (self , inputs , outputs , output_gradients ):
377
+ res = super ().L_op (inputs , outputs , output_gradients )
378
+
379
+ if self .lower :
380
+ res [0 ] = aet .tril (res [0 ])
381
+ else :
382
+ res [0 ] = aet .triu (res [0 ])
383
+
384
+ return res
385
+
386
+
387
+ solvetriangular = SolveTriangular ()
388
+
389
+
390
+ def solve_triangular (
391
+ a : TensorVariable ,
392
+ b : TensorVariable ,
393
+ trans : Union [int , str ] = 0 ,
394
+ lower : bool = False ,
395
+ unit_diagonal : bool = False ,
396
+ check_finite : bool = True ,
397
+ ) -> TensorVariable :
398
+ """Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
399
+
400
+ Parameters
401
+ ----------
402
+ a
403
+ Square input data
404
+ b
405
+ Input data for the right hand side.
406
+ lower : bool, optional
407
+ Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
408
+ trans: {0, 1, 2, ‘N’, ‘T’, ‘C’}, optional
409
+ Type of system to solve:
410
+ trans system
411
+ 0 or 'N' a x = b
412
+ 1 or 'T' a^T x = b
413
+ 2 or 'C' a^H x = b
414
+ unit_diagonal: bool, optional
415
+ If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
416
+ check_finite : bool, optional
417
+ Whether to check that the input matrices contain only finite numbers.
418
+ Disabling may give a performance gain, but may result in problems
419
+ (crashes, non-termination) if the inputs do contain infinities or NaNs.
420
+ """
421
+ return SolveTriangular (
422
+ lower = lower ,
423
+ trans = trans ,
424
+ unit_diagonal = unit_diagonal ,
425
+ check_finite = check_finite ,
426
+ )(a , b )
427
+
428
+
429
+ class Solve (SolveBase ):
430
+ """
431
+ Solve a system of linear equations.
432
+
433
+ For on CPU and GPU.
434
+ """
435
+
436
+ __props__ = (
437
+ "assume_a" ,
438
+ "lower" ,
439
+ "check_finite" ,
440
+ )
441
+
442
+ def __init__ (
443
+ self ,
444
+ assume_a = "gen" ,
445
+ lower = False ,
446
+ check_finite = True ,
447
+ ):
448
+ if assume_a not in ("gen" , "sym" , "her" , "pos" ):
449
+ raise ValueError (f"{ assume_a } is not a recognized matrix structure" )
450
+
451
+ super ().__init__ (lower = lower , check_finite = check_finite )
452
+ self .assume_a = assume_a
453
+
454
+ def perform (self , node , inputs , outputs ):
455
+ a , b = inputs
456
+ outputs [0 ][0 ] = scipy .linalg .solve (
457
+ a = a ,
458
+ b = b ,
459
+ lower = self .lower ,
460
+ check_finite = self .check_finite ,
461
+ assume_a = self .assume_a ,
462
+ )
463
+
384
464
385
465
solve = Solve ()
386
466
387
467
388
468
def solve (a , b , assume_a = "gen" , lower = False , check_finite = True ):
389
- """
390
- Solves the linear equation set ``a * x = b`` for the unknown ``x``
391
- for square ``a`` matrix.
469
+ """Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
392
470
393
471
If the data matrix is known to be a particular type then supplying the
394
472
corresponding string to ``assume_a`` key chooses the dedicated solver.
@@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
432
510
433
511
434
512
# TODO: These are deprecated; emit a warning
435
- solve_lower_triangular = Solve ( assume_a = "sym" , lower = True )
436
- solve_upper_triangular = Solve ( assume_a = "sym" , lower = False )
513
+ solve_lower_triangular = SolveTriangular ( lower = True )
514
+ solve_upper_triangular = SolveTriangular ( lower = False )
437
515
solve_symmetric = Solve (assume_a = "sym" )
438
516
439
517
# TODO: Optimizations to replace multiplication by matrix inverse
0 commit comments