Skip to content

Commit 79961a6

Browse files
fshart authored and brandonwillard committed
Add a SolveTriangular Op
`Solve` has also been changed to match SciPy.
1 parent 6fce270 commit 79961a6

File tree

3 files changed

+315
-169
lines changed

3 files changed

+315
-169
lines changed

aesara/tensor/slinalg.py

Lines changed: 149 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import logging
22
import warnings
3+
from typing import Union
34

45
import numpy as np
56
import scipy.linalg
@@ -11,6 +12,7 @@
1112
from aesara.tensor import basic as aet
1213
from aesara.tensor import math as atm
1314
from aesara.tensor.type import matrix, tensor, vector
15+
from aesara.tensor.var import TensorVariable
1416

1517

1618
logger = logging.getLogger(__name__)
@@ -259,93 +261,52 @@ def cho_solve(c_and_lower, b, check_finite=True):
259261
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
260262

261263

262-
class Solve(Op):
263-
"""
264-
Solve a system of linear equations.
265-
266-
For on CPU and GPU.
267-
"""
264+
class SolveBase(Op):
265+
"""Base class for `scipy.linalg` matrix equation solvers."""
268266

269267
__props__ = (
270-
"assume_a",
271268
"lower",
272-
"check_finite", # "transposed"
269+
"check_finite",
273270
)
274271

275272
def __init__(
276273
self,
277-
assume_a="gen",
278274
lower=False,
279-
check_finite=True, # transposed=False
275+
check_finite=True,
280276
):
281-
if assume_a not in ("gen", "sym", "her", "pos"):
282-
raise ValueError(f"{assume_a} is not a recognized matrix structure")
283-
self.assume_a = assume_a
284277
self.lower = lower
285278
self.check_finite = check_finite
286-
# self.transposed = transposed
287279

288-
def __repr__(self):
289-
return "Solve{%s}" % str(self._props())
280+
def perform(self, node, inputs, outputs):
281+
pass
290282

291283
def make_node(self, A, b):
292284
A = as_tensor_variable(A)
293285
b = as_tensor_variable(b)
294-
assert A.ndim == 2
295-
assert b.ndim in [1, 2]
296286

297-
# infer dtype by solving the most simple
298-
# case with (1, 1) matrices
287+
if A.ndim != 2:
288+
raise ValueError(f"`A` must be a matrix; got {A.type} instead.")
289+
if b.ndim not in [1, 2]:
290+
raise ValueError(f"`b` must be a matrix or a vector; got {b.type} instead.")
291+
292+
# Infer dtype by solving the most simple case with 1x1 matrices
299293
o_dtype = scipy.linalg.solve(
300294
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
301295
).dtype
302296
x = tensor(broadcastable=b.broadcastable, dtype=o_dtype)
303297
return Apply(self, [A, b], [x])
304298

305-
def perform(self, node, inputs, output_storage):
306-
A, b = inputs
307-
308-
if self.assume_a != "gen":
309-
# if self.transposed:
310-
# if self.assume_a == "her":
311-
# trans = "C"
312-
# else:
313-
# trans = "T"
314-
# else:
315-
# trans = "N"
316-
317-
rval = scipy.linalg.solve_triangular(
318-
A,
319-
b,
320-
lower=self.lower,
321-
check_finite=self.check_finite,
322-
# trans=trans
323-
)
324-
else:
325-
rval = scipy.linalg.solve(
326-
A,
327-
b,
328-
assume_a=self.assume_a,
329-
lower=self.lower,
330-
check_finite=self.check_finite,
331-
# transposed=self.transposed,
332-
)
333-
334-
output_storage[0][0] = rval
335-
336-
# computes shape of x where x = inv(A) * b
337299
def infer_shape(self, fgraph, node, shapes):
338300
Ashape, Bshape = shapes
339301
rows = Ashape[1]
340-
if len(Bshape) == 1: # b is a Vector
302+
if len(Bshape) == 1:
341303
return [(rows,)]
342304
else:
343-
cols = Bshape[1] # b is a Matrix
305+
cols = Bshape[1]
344306
return [(rows, cols)]
345307

346308
def L_op(self, inputs, outputs, output_gradients):
347-
r"""
348-
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
309+
r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
349310
350311
Symbolic expression for updates taken from [#]_.
351312
@@ -364,31 +325,148 @@ def L_op(self, inputs, outputs, output_gradients):
364325
# We need to return (dC/d[inv(A)], dC/db)
365326
c_bar = output_gradients[0]
366327

367-
trans_solve_op = Solve(
368-
assume_a=self.assume_a,
369-
check_finite=self.check_finite,
370-
lower=not self.lower,
328+
trans_solve_op = type(self)(
329+
**{
330+
k: (not getattr(self, k) if k == "lower" else getattr(self, k))
331+
for k in self.__props__
332+
}
371333
)
372334
b_bar = trans_solve_op(A.T, c_bar)
373335
# force outer product if vector second input
374336
A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
375337

376-
if self.assume_a != "gen":
377-
if self.lower:
378-
A_bar = aet.tril(A_bar)
379-
else:
380-
A_bar = aet.triu(A_bar)
381-
382338
return [A_bar, b_bar]
383339

340+
def __repr__(self):
341+
return f"{type(self).__name__}{self._props()}"
342+
343+
344+
class SolveTriangular(SolveBase):
345+
"""Solve a system of linear equations."""
346+
347+
__props__ = (
348+
"lower",
349+
"trans",
350+
"unit_diagonal",
351+
"check_finite",
352+
)
353+
354+
def __init__(
355+
self,
356+
trans=0,
357+
lower=False,
358+
unit_diagonal=False,
359+
check_finite=True,
360+
):
361+
super().__init__(lower=lower, check_finite=check_finite)
362+
self.trans = trans
363+
self.unit_diagonal = unit_diagonal
364+
365+
def perform(self, node, inputs, outputs):
366+
A, b = inputs
367+
outputs[0][0] = scipy.linalg.solve_triangular(
368+
A,
369+
b,
370+
lower=self.lower,
371+
trans=self.trans,
372+
unit_diagonal=self.unit_diagonal,
373+
check_finite=self.check_finite,
374+
)
375+
376+
def L_op(self, inputs, outputs, output_gradients):
377+
res = super().L_op(inputs, outputs, output_gradients)
378+
379+
if self.lower:
380+
res[0] = aet.tril(res[0])
381+
else:
382+
res[0] = aet.triu(res[0])
383+
384+
return res
385+
386+
387+
solvetriangular = SolveTriangular()
388+
389+
390+
def solve_triangular(
391+
a: TensorVariable,
392+
b: TensorVariable,
393+
trans: Union[int, str] = 0,
394+
lower: bool = False,
395+
unit_diagonal: bool = False,
396+
check_finite: bool = True,
397+
) -> TensorVariable:
398+
"""Solve the equation `a x = b` for `x`, assuming `a` is a triangular matrix.
399+
400+
Parameters
401+
----------
402+
a
403+
Square input data
404+
b
405+
Input data for the right hand side.
406+
lower : bool, optional
407+
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
408+
trans: {0, 1, 2, ‘N’, ‘T’, ‘C’}, optional
409+
Type of system to solve:
410+
trans system
411+
0 or 'N' a x = b
412+
1 or 'T' a^T x = b
413+
2 or 'C' a^H x = b
414+
unit_diagonal: bool, optional
415+
If True, diagonal elements of `a` are assumed to be 1 and will not be referenced.
416+
check_finite : bool, optional
417+
Whether to check that the input matrices contain only finite numbers.
418+
Disabling may give a performance gain, but may result in problems
419+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
420+
"""
421+
return SolveTriangular(
422+
lower=lower,
423+
trans=trans,
424+
unit_diagonal=unit_diagonal,
425+
check_finite=check_finite,
426+
)(a, b)
427+
428+
429+
class Solve(SolveBase):
430+
"""
431+
Solve a system of linear equations.
432+
433+
For on CPU and GPU.
434+
"""
435+
436+
__props__ = (
437+
"assume_a",
438+
"lower",
439+
"check_finite",
440+
)
441+
442+
def __init__(
443+
self,
444+
assume_a="gen",
445+
lower=False,
446+
check_finite=True,
447+
):
448+
if assume_a not in ("gen", "sym", "her", "pos"):
449+
raise ValueError(f"{assume_a} is not a recognized matrix structure")
450+
451+
super().__init__(lower=lower, check_finite=check_finite)
452+
self.assume_a = assume_a
453+
454+
def perform(self, node, inputs, outputs):
455+
a, b = inputs
456+
outputs[0][0] = scipy.linalg.solve(
457+
a=a,
458+
b=b,
459+
lower=self.lower,
460+
check_finite=self.check_finite,
461+
assume_a=self.assume_a,
462+
)
463+
384464

385465
solve = Solve()
386466

387467

388468
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
389-
"""
390-
Solves the linear equation set ``a * x = b`` for the unknown ``x``
391-
for square ``a`` matrix.
469+
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
392470
393471
If the data matrix is known to be a particular type then supplying the
394472
corresponding string to ``assume_a`` key chooses the dedicated solver.
@@ -432,8 +510,8 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True):
432510

433511

434512
# TODO: These are deprecated; emit a warning
435-
solve_lower_triangular = Solve(assume_a="sym", lower=True)
436-
solve_upper_triangular = Solve(assume_a="sym", lower=False)
513+
solve_lower_triangular = SolveTriangular(lower=True)
514+
solve_upper_triangular = SolveTriangular(lower=False)
437515
solve_symmetric = Solve(assume_a="sym")
438516

439517
# TODO: Optimizations to replace multiplication by matrix inverse

tests/link/test_numba.py

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2174,6 +2174,31 @@ def test_Cholesky(x, lower, exc):
21742174
"gen",
21752175
None,
21762176
),
2177+
],
2178+
)
2179+
def test_Solve(A, x, lower, exc):
2180+
g = slinalg.Solve(lower)(A, x)
2181+
2182+
if isinstance(g, list):
2183+
g_fg = FunctionGraph(outputs=g)
2184+
else:
2185+
g_fg = FunctionGraph(outputs=[g])
2186+
2187+
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
2188+
with cm:
2189+
compare_numba_and_py(
2190+
g_fg,
2191+
[
2192+
i.tag.test_value
2193+
for i in g_fg.inputs
2194+
if not isinstance(i, (SharedVariable, Constant))
2195+
],
2196+
)
2197+
2198+
2199+
@pytest.mark.parametrize(
2200+
"A, x, lower, exc",
2201+
[
21772202
(
21782203
set_test_value(
21792204
aet.dmatrix(),
@@ -2185,8 +2210,8 @@ def test_Cholesky(x, lower, exc):
21852210
),
21862211
],
21872212
)
2188-
def test_Solve(A, x, lower, exc):
2189-
g = slinalg.Solve(lower)(A, x)
2213+
def test_SolveTriangular(A, x, lower, exc):
2214+
g = slinalg.SolveTriangular(lower)(A, x)
21902215

21912216
if isinstance(g, list):
21922217
g_fg = FunctionGraph(outputs=g)

0 commit comments

Comments
 (0)