Skip to content

Commit 86282bd

Browse files
Update aesara.tensor.slinalg.Solve to match SciPy interface
1 parent a6e461b commit 86282bd

File tree

8 files changed

+174
-102
lines changed

8 files changed

+174
-102
lines changed

aesara/link/jax/dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ def cholesky(a, lower=lower):
800800
@jax_funcify.register(Solve)
801801
def jax_funcify_Solve(op, **kwargs):
802802

803-
if op.A_structure == "lower_triangular":
803+
if op.assume_a != "gen" and op.lower:
804804
lower = True
805805
else:
806806
lower = False

aesara/link/numba/dispatch.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,9 +1690,12 @@ def cholesky(a):
16901690
@numba_funcify.register(Solve)
16911691
def numba_funcify_Solve(op, node, **kwargs):
16921692

1693-
if op.A_structure == "lower_triangular" or op.A_structure == "upper_triangular":
1693+
assume_a = op.assume_a
1694+
# check_finite = op.check_finite
16941695

1695-
lower = op.A_structure == "lower_triangular"
1696+
if assume_a != "gen":
1697+
1698+
lower = op.lower
16961699

16971700
warnings.warn(
16981701
(
@@ -1707,16 +1710,26 @@ def numba_funcify_Solve(op, node, **kwargs):
17071710
@numba.njit
17081711
def solve(a, b):
17091712
with numba.objmode(ret=ret_sig):
1710-
ret = scipy.linalg.solve_triangular(a, b, lower=lower)
1713+
ret = scipy.linalg.solve_triangular(
1714+
a,
1715+
b,
1716+
lower=lower,
1717+
# check_finite=check_finite
1718+
)
17111719
return ret
17121720

17131721
else:
17141722
out_dtype = node.outputs[0].type.numpy_dtype
17151723
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
17161724

1717-
@numba.njit
1725+
@numba.njit(inline="always")
17181726
def solve(a, b):
1719-
return np.linalg.solve(inputs_cast(a), inputs_cast(b)).astype(out_dtype)
1727+
return np.linalg.solve(
1728+
inputs_cast(a),
1729+
inputs_cast(b),
1730+
# assume_a=assume_a,
1731+
# check_finite=check_finite,
1732+
).astype(out_dtype)
17201733

17211734
return solve
17221735

aesara/sandbox/linalg/ops.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,25 +249,25 @@ def tag_solve_triangular(fgraph, node):
249249
replace it with a triangular solve.
250250
251251
"""
252-
if node.op == solve:
253-
if node.op.A_structure == "general":
252+
if isinstance(node.op, Solve):
253+
if node.op.assume_a == "gen":
254254
A, b = node.inputs # result is solution Ax=b
255-
if A.owner and isinstance(A.owner.op, type(cholesky)):
255+
if A.owner and isinstance(A.owner.op, Cholesky):
256256
if A.owner.op.lower:
257-
return [Solve("lower_triangular")(A, b)]
257+
return [Solve(assume_a="sym", lower=True)(A, b)]
258258
else:
259-
return [Solve("upper_triangular")(A, b)]
259+
return [Solve(assume_a="sym", lower=False)(A, b)]
260260
if (
261261
A.owner
262262
and isinstance(A.owner.op, DimShuffle)
263263
and A.owner.op.new_order == (1, 0)
264264
):
265265
(A_T,) = A.owner.inputs
266-
if A_T.owner and isinstance(A_T.owner.op, type(cholesky)):
266+
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
267267
if A_T.owner.op.lower:
268-
return [Solve("upper_triangular")(A, b)]
268+
return [Solve(assume_a="sym", lower=False)(A, b)]
269269
else:
270-
return [Solve("lower_triangular")(A, b)]
270+
return [Solve(assume_a="sym", lower=True)(A, b)]
271271

272272

273273
@register_canonicalize
@@ -286,15 +286,15 @@ def no_transpose_symmetric(fgraph, node):
286286
@register_stabilize
287287
@local_optimizer(None) # XXX: solve is defined later and can't be used here
288288
def psd_solve_with_chol(fgraph, node):
289-
if node.op == solve:
289+
if isinstance(node.op, Solve):
290290
A, b = node.inputs # result is solution Ax=b
291291
if is_psd(A):
292292
L = cholesky(A)
293293
# N.B. this can be further reduced to a yet-unwritten cho_solve Op
294294
# __if__ no other Op makes use of the the L matrix during the
295295
# stabilization
296-
Li_b = Solve("lower_triangular")(L, b)
297-
x = Solve("upper_triangular")(L.T, Li_b)
296+
Li_b = Solve(assume_a="sym", lower=True)(L, b)
297+
x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
298298
return [x]
299299

300300

aesara/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def _as_tensor_variable(
5959
nlinalg,
6060
nnet,
6161
opt_uncanonicalize,
62+
slinalg,
6263
xlogx,
6364
)
6465
from aesara.tensor.basic import *

aesara/tensor/slinalg.py

Lines changed: 112 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,16 @@
55
import scipy.linalg
66

77
import aesara.tensor
8-
import aesara.tensor.basic as aet
9-
import aesara.tensor.math as tm
108
from aesara.graph.basic import Apply
119
from aesara.graph.op import Op
1210
from aesara.tensor import as_tensor_variable
11+
from aesara.tensor import basic as aet
12+
from aesara.tensor import math as atm
1313
from aesara.tensor.type import matrix, tensor, vector
1414

1515

1616
logger = logging.getLogger(__name__)
1717

18-
MATRIX_STRUCTURES = (
19-
"general",
20-
"symmetric",
21-
"lower_triangular",
22-
"upper_triangular",
23-
"hermitian",
24-
"banded",
25-
"diagonal",
26-
"toeplitz",
27-
)
28-
2918

3019
class Cholesky(Op):
3120
"""
@@ -95,7 +84,7 @@ def L_op(self, inputs, outputs, gradients):
9584
# Replace the cholesky decomposition with 1 if there are nans
9685
# or solve_upper_triangular will throw a ValueError.
9786
if self.on_error == "nan":
98-
ok = ~tm.any(tm.isnan(chol_x))
87+
ok = ~atm.any(atm.isnan(chol_x))
9988
chol_x = aet.switch(ok, chol_x, 1)
10089
dz = aet.switch(ok, dz, 1)
10190

@@ -206,17 +195,24 @@ class Solve(Op):
206195
For on CPU and GPU.
207196
"""
208197

209-
__props__ = ("A_structure", "lower", "overwrite_A", "overwrite_b")
198+
__props__ = (
199+
"assume_a",
200+
"lower",
201+
"check_finite", # "transposed"
202+
)
210203

211204
def __init__(
212-
self, A_structure="general", lower=False, overwrite_A=False, overwrite_b=False
205+
self,
206+
assume_a="gen",
207+
lower=False,
208+
check_finite=True, # transposed=False
213209
):
214-
if A_structure not in MATRIX_STRUCTURES:
215-
raise ValueError("Invalid matrix structure argument", A_structure)
216-
self.A_structure = A_structure
210+
if assume_a not in ("gen", "sym", "her", "pos"):
211+
raise ValueError(f"{assume_a} is not a recognized matrix structure")
212+
self.assume_a = assume_a
217213
self.lower = lower
218-
self.overwrite_A = overwrite_A
219-
self.overwrite_b = overwrite_b
214+
self.check_finite = check_finite
215+
# self.transposed = transposed
220216

221217
def __repr__(self):
222218
return "Solve{%s}" % str(self._props())
@@ -237,12 +233,33 @@ def make_node(self, A, b):
237233

238234
def perform(self, node, inputs, output_storage):
239235
A, b = inputs
240-
if self.A_structure == "lower_triangular":
241-
rval = scipy.linalg.solve_triangular(A, b, lower=True)
242-
elif self.A_structure == "upper_triangular":
243-
rval = scipy.linalg.solve_triangular(A, b, lower=False)
236+
237+
if self.assume_a != "gen":
238+
# if self.transposed:
239+
# if self.assume_a == "her":
240+
# trans = "C"
241+
# else:
242+
# trans = "T"
243+
# else:
244+
# trans = "N"
245+
246+
rval = scipy.linalg.solve_triangular(
247+
A,
248+
b,
249+
lower=self.lower,
250+
check_finite=self.check_finite,
251+
# trans=trans
252+
)
244253
else:
245-
rval = scipy.linalg.solve(A, b)
254+
rval = scipy.linalg.solve(
255+
A,
256+
b,
257+
assume_a=self.assume_a,
258+
lower=self.lower,
259+
check_finite=self.check_finite,
260+
# transposed=self.transposed,
261+
)
262+
246263
output_storage[0][0] = rval
247264

248265
# computes shape of x where x = inv(A) * b
@@ -257,7 +274,7 @@ def infer_shape(self, fgraph, node, shapes):
257274

258275
def L_op(self, inputs, outputs, output_gradients):
259276
r"""
260-
Reverse-mode gradient updates for matrix solve operation c = A \\\ b.
277+
Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
261278
262279
Symbolic expression for updates taken from [#]_.
263280
@@ -269,53 +286,84 @@ def L_op(self, inputs, outputs, output_gradients):
269286
270287
"""
271288
A, b = inputs
289+
272290
c = outputs[0]
291+
# C is a scalar representing the entire graph
292+
# `output_gradients` is (dC/dc,)
293+
# We need to return (dC/d[inv(A)], dC/db)
273294
c_bar = output_gradients[0]
274-
trans_map = {
275-
"lower_triangular": "upper_triangular",
276-
"upper_triangular": "lower_triangular",
277-
}
295+
278296
trans_solve_op = Solve(
279-
# update A_structure and lower to account for a transpose operation
280-
A_structure=trans_map.get(self.A_structure, self.A_structure),
297+
assume_a=self.assume_a,
298+
check_finite=self.check_finite,
281299
lower=not self.lower,
282300
)
283301
b_bar = trans_solve_op(A.T, c_bar)
284302
# force outer product if vector second input
285-
A_bar = -tm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
286-
if self.A_structure == "lower_triangular":
287-
A_bar = aet.tril(A_bar)
288-
elif self.A_structure == "upper_triangular":
289-
A_bar = aet.triu(A_bar)
303+
A_bar = -atm.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
304+
305+
if self.assume_a != "gen":
306+
if self.lower:
307+
A_bar = aet.tril(A_bar)
308+
else:
309+
A_bar = aet.triu(A_bar)
310+
290311
return [A_bar, b_bar]
291312

292313

293314
solve = Solve()
294-
"""
295-
Solves the equation ``a x = b`` for x, where ``a`` is a matrix and
296-
``b`` can be either a vector or a matrix.
297-
298-
Parameters
299-
----------
300-
a : `(M, M) symbolix matrix`
301-
A square matrix
302-
b : `(M,) or (M, N) symbolic vector or matrix`
303-
Right hand side matrix in ``a x = b``
304-
305-
306-
Returns
307-
-------
308-
x : `(M, ) or (M, N) symbolic vector or matrix`
309-
x will have the same shape as b
310-
"""
311-
# lower and upper triangular solves
312-
solve_lower_triangular = Solve(A_structure="lower_triangular", lower=True)
313-
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is lower triangular."""
314-
solve_upper_triangular = Solve(A_structure="upper_triangular", lower=False)
315-
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is upper triangular."""
316-
# symmetric solves
317-
solve_symmetric = Solve(A_structure="symmetric")
318-
"""Optimized implementation of :func:`aesara.tensor.slinalg.solve` when A is symmetric."""
315+
316+
317+
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
318+
"""
319+
Solves the linear equation set ``a * x = b`` for the unknown ``x``
320+
for square ``a`` matrix.
321+
322+
If the data matrix is known to be a particular type then supplying the
323+
corresponding string to ``assume_a`` key chooses the dedicated solver.
324+
The available options are
325+
326+
=================== ========
327+
generic matrix 'gen'
328+
symmetric 'sym'
329+
hermitian 'her'
330+
positive definite 'pos'
331+
=================== ========
332+
333+
If omitted, ``'gen'`` is the default structure.
334+
335+
The datatype of the arrays define which solver is called regardless
336+
of the values. In other words, even when the complex array entries have
337+
precisely zero imaginary parts, the complex solver will be called based
338+
on the data type of the array.
339+
340+
Parameters
341+
----------
342+
a : (N, N) array_like
343+
Square input data
344+
b : (N, NRHS) array_like
345+
Input data for the right hand side.
346+
lower : bool, optional
347+
If True, only the data contained in the lower triangle of `a`. Default
348+
is to use upper triangle. (ignored for ``'gen'``)
349+
check_finite : bool, optional
350+
Whether to check that the input matrices contain only finite numbers.
351+
Disabling may give a performance gain, but may result in problems
352+
(crashes, non-termination) if the inputs do contain infinities or NaNs.
353+
assume_a : str, optional
354+
Valid entries are explained above.
355+
"""
356+
return Solve(
357+
lower=lower,
358+
check_finite=check_finite,
359+
assume_a=assume_a,
360+
)(a, b)
361+
362+
363+
# TODO: These are deprecated; emit a warning
364+
solve_lower_triangular = Solve(assume_a="sym", lower=True)
365+
solve_upper_triangular = Solve(assume_a="sym", lower=False)
366+
solve_symmetric = Solve(assume_a="sym")
319367

320368
# TODO: Optimizations to replace multiplication by matrix inverse
321369
# with solve() Op (still unwritten)
@@ -456,7 +504,7 @@ def kron(a, b):
456504
"kron: inputs dimensions must sum to 3 or more. "
457505
f"You passed {int(a.ndim)} and {int(b.ndim)}."
458506
)
459-
o = tm.outer(a, b)
507+
o = atm.outer(a, b)
460508
o = o.reshape(aet.concatenate((a.shape, b.shape)), a.ndim + b.ndim)
461509
shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim)))
462510
if shf.ndim == 3:

tests/link/test_numba.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,7 +2000,7 @@ def test_Cholesky(x, lower, exc):
20002000
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
20012001
),
20022002
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
2003-
"general",
2003+
"gen",
20042004
None,
20052005
),
20062006
(
@@ -2011,7 +2011,7 @@ def test_Cholesky(x, lower, exc):
20112011
),
20122012
),
20132013
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
2014-
"general",
2014+
"gen",
20152015
None,
20162016
),
20172017
(
@@ -2020,7 +2020,7 @@ def test_Cholesky(x, lower, exc):
20202020
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
20212021
),
20222022
set_test_value(aet.dvector(), rng.random(size=(3,)).astype("float64")),
2023-
"lower_triangular",
2023+
"sym",
20242024
UserWarning,
20252025
),
20262026
],

0 commit comments

Comments
 (0)