Commit c4ae6e3

Add linalg.block_diag and sparse equivalent (#576)
* Copy `block_diag` and support functions from `pymc.math`
* Evaluate output in sphinx code example
* Test type equivalence with `isinstance` instead of `==`
* Fix typo in test function
* Split `block_diag` into sparse and dense versions, closely following the scipy function signature for `block_diag`
* Use `as_sparse_or_tensor_variable` in `SparseBlockDiagonalMatrix` to allow sparse matrix inputs to `pytensor.sparse.block_diag`
* Test sparse and dense inputs to `pytensor.sparse.block_diag`
* Add Numba overload for `pytensor.tensor.slinalg.block_diag`
* Add JAX overload for `pytensor.tensor.slinalg.block_diag`
* Move stand-alone `block_diag_grad` function into the `grad` method
* Add `format` prop to `SparseBlockDiagonalMatrix`
* Use `compare_numba_and_py` in `numba/test_slinalg.py::test_block_diag`
* Add support for `Blockwise` to `slinalg.block_diag`
* Add gradient test; remove `Matrix` from the `BlockDiagonal` and `SparseBlockDiagonal` `Op` names; correct errors in docstrings; move input validation to a shared class method
* Remove `gufunc_signature` from `__props__`
* Implement correct `__props__` for subclasses of `BaseBlockMatrix`

---------

Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 96f753b commit c4ae6e3
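
For orientation before the per-file diffs, a minimal sketch of the dense API this commit adds (a hypothetical session; assumes a PyTensor build that includes this commit):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    # Two 2x2 blocks placed on the diagonal of a 4x4 result
    A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
    B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
    print(block_diag(A, B).eval())
    # [[1 2 0 0]
    #  [3 4 0 0]
    #  [0 0 5 6]
    #  [0 0 7 8]]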

File tree

9 files changed (+348, -13 lines)

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 9 additions & 1 deletion
@@ -1,7 +1,7 @@
 import jax
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
 
 
 @jax_funcify.register(Cholesky)
@@ -45,3 +45,11 @@ def solve_triangular(A, b):
         )
 
     return solve_triangular
+
+
+@jax_funcify.register(BlockDiagonal)
+def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
+    def block_diag(*inputs):
+        return jax.scipy.linalg.block_diag(*inputs)
+
+    return block_diag
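
The JAX overload is a direct passthrough to `jax.scipy.linalg.block_diag`, so its semantics can be checked in plain JAX; a sketch, independent of PyTensor (blocks need not share shapes):

    import jax.numpy as jnp
    from jax.scipy.linalg import block_diag

    A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    B = jnp.array([[5.0]])
    print(block_diag(A, B))  # blocks on the diagonal, zeros elsewhere
    # [[1. 2. 0.]
    #  [3. 4. 0.]
    #  [0. 0. 5.]]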

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 23 additions & 1 deletion
@@ -9,7 +9,7 @@
 
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import numba_funcify
-from pytensor.tensor.slinalg import SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
 
 
 _PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
         return res
 
     return solve_triangular
+
+
+@numba_funcify.register(BlockDiagonal)
+def numba_funcify_BlockDiagonal(op, node, **kwargs):
+    dtype = node.outputs[0].dtype
+
+    # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
+    @numba_basic.numba_njit(inline="never")
+    def block_diag(*arrs):
+        shapes = np.array([a.shape for a in arrs], dtype="int")
+        out_shape = [int(s) for s in np.sum(shapes, axis=0)]
+        out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)
+
+        r, c = 0, 0
+        for arr, shape in zip(arrs, shapes):
+            rr, cc = shape
+            out[r : r + rr, c : c + cc] = arr
+            r += rr
+            c += cc
+        return out
+
+    return block_diag
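
The Numba kernel builds the output by summing the per-block shapes and pasting each block at a running (row, column) offset. A pure-NumPy reference of the same loop, useful for checking the placement logic outside Numba (a sketch, not part of the commit):

    import numpy as np

    def block_diag_ref(*arrs, dtype="float64"):
        # Output shape is the elementwise sum of the block shapes.
        shapes = np.array([a.shape for a in arrs])
        out = np.zeros(tuple(shapes.sum(axis=0)), dtype=dtype)
        r = c = 0
        for arr, (rr, cc) in zip(arrs, shapes):
            out[r : r + rr, c : c + cc] = arr  # paste block at the current offset
            r += rr
            c += cc
        return out

    print(block_diag_ref(np.ones((2, 2)), np.full((1, 3), 2.0)))
    # [[1. 1. 0. 0. 0.]
    #  [1. 1. 0. 0. 0.]
    #  [0. 0. 2. 2. 2.]]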

pytensor/sparse/basic.py

Lines changed: 87 additions & 9 deletions
@@ -7,6 +7,7 @@
 TODO: Automatic methods for determining best sparse format?
 
 """
+from typing import Literal
 from warnings import warn
 
 import numpy as np
@@ -47,6 +48,7 @@
     trunc,
 )
 from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@
 
 sparse_formats = ["csc", "csr"]
 
-
 """
 Types of sparse matrices to use for testing.
 
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
 
 as_sparse = as_sparse_variable
 
-
 as_sparse_or_tensor_variable = as_symbolic
 
 
@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
         return r
 
     def __str__(self):
-        return f"{self.__class__.__name__ }{{axis={self.axis}}}"
+        return f"{self.__class__.__name__}{{axis={self.axis}}}"
 
 
 def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):
 
 greater_equal_s_d = GreaterEqualSD()
 
-
 eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
 
-
 neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
 
-
 lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
 
-
 gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
 
-
 le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
 
 ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
         l = []
         if self.inplace:
             l.append("inplace")
-        return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
+        return f"{self.__class__.__name__}{{{', '.join(l)}}}"
 
     def make_node(self, x):
         """
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
     # Simplify code by splitting into DotSS and DotSD.
 
     __props__ = ()
+
     # The grad_preserves_dense attribute doesn't change the
     # execution behavior. To let the optimizer merge nodes with
     # different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,85 @@ def grad(self, inputs, grads):
 
 
 construct_sparse_from_list = ConstructSparseFromList()
+
+
+class SparseBlockDiagonal(BaseBlockDiagonal):
+    __props__ = (
+        "n_inputs",
+        "format",
+    )
+
+    def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
+        super().__init__(n_inputs)
+        self.format = format
+
+    def make_node(self, *matrices):
+        matrices = self._validate_and_prepare_inputs(
+            matrices, as_sparse_or_tensor_variable
+        )
+        dtype = _largest_common_dtype(matrices)
+        out_type = matrix(format=self.format, dtype=dtype)
+
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        dtype = node.outputs[0].type.dtype
+        output_storage[0][0] = scipy.sparse.block_diag(
+            inputs, format=self.format
+        ).astype(dtype)
+
+
+def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
+    r"""
+    Construct a block diagonal matrix from a sequence of input matrices.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
+
+        Note that the input matrices need not be sparse themselves; they will be converted to the
+        requested format automatically if they are not.
+    format: str, optional
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
+
+    Returns
+    -------
+    out: sparse matrix tensor
+        Symbolic sparse matrix in the specified format.
+
+    Examples
+    --------
+    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
+
+    .. code-block:: python
+
+        import numpy as np
+        from pytensor.sparse import block_diag
+        from scipy.sparse import csr_matrix
+
+        A = csr_matrix([[1, 2], [3, 4]])
+        B = csr_matrix([[5, 6], [7, 8]])
+        result_sparse = block_diag(A, B, format='csr')
+
+        print(result_sparse)
+        >>> SparseVariable{csr,int32}
+
+        print(result_sparse.toarray().eval())
+        >>> array([[1, 2, 0, 0],
+        >>>        [3, 4, 0, 0],
+        >>>        [0, 0, 5, 6],
+        >>>        [0, 0, 7, 8]])
+    """
+    if len(matrices) == 1:
+        return matrices
+
+    _sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
+    return _sparse_block_diagonal(*matrices)
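
Because `make_node` routes inputs through `as_sparse_or_tensor_variable`, sparse and dense operands can be mixed in a single call; a hedged sketch (assumes this commit is installed):

    import numpy as np
    import scipy.sparse
    from pytensor.sparse import block_diag

    A = scipy.sparse.csr_matrix([[1, 0], [0, 2]])  # sparse input
    B = np.array([[3, 4], [5, 6]])                 # dense input, converted automatically
    X = block_diag(A, B, format="csc")
    print(X.toarray().eval())
    # [[1 0 0 0]
    #  [0 2 0 0]
    #  [0 0 3 4]
    #  [0 0 5 6]]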

pytensor/tensor/basic.py

Lines changed: 19 additions & 0 deletions
@@ -4279,6 +4279,25 @@ def take_along_axis(arr, indices, axis=0):
     return arr[_make_along_axis_idx(arr.shape, indices, axis)]
 
 
+def ix_(*args):
+    """
+    PyTensor np.ix_ analog
+
+    See numpy.lib.index_tricks.ix_ for reference
+    """
+    out = []
+    nd = len(args)
+    for k, new in enumerate(args):
+        if new is None:
+            out.append(slice(None))
+        new = as_tensor(new)
+        if new.ndim != 1:
+            raise ValueError("Cross index must be 1 dimensional")
+        new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
+        out.append(new)
+    return tuple(out)
+
+
 __all__ = [
     "take_along_axis",
     "expand_dims",

pytensor/tensor/slinalg.py

Lines changed: 103 additions & 1 deletion
@@ -1,6 +1,7 @@
 import logging
 import typing
 import warnings
+from functools import reduce
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
@@ -23,7 +24,6 @@
 if TYPE_CHECKING:
     from pytensor.tensor import TensorLike
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -908,6 +908,107 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     )
 
 
+def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+    return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
+
+
+class BaseBlockDiagonal(Op):
+    __props__ = ("n_inputs",)
+
+    def __init__(self, n_inputs):
+        input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
+        self.gufunc_signature = f"{input_sig}->(m,n)"
+
+        if n_inputs == 0:
+            raise ValueError("n_inputs must be greater than 0")
+        self.n_inputs = n_inputs
+
+    def grad(self, inputs, gout):
+        shapes = pt.stack([i.shape for i in inputs])
+        index_end = shapes.cumsum(0)
+        index_begin = index_end - shapes
+        slices = [
+            ptb.ix_(
+                pt.arange(index_begin[i, 0], index_end[i, 0]),
+                pt.arange(index_begin[i, 1], index_end[i, 1]),
+            )
+            for i in range(len(inputs))
+        ]
+        return [gout[0][slc] for slc in slices]
+
+    def infer_shape(self, fgraph, nodes, shapes):
+        first, second = zip(*shapes)
+        return [(pt.add(*first), pt.add(*second))]
+
+    def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
+        if len(matrices) != self.n_inputs:
+            raise ValueError(
+                f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
+            )
+        matrices = list(map(as_tensor_func, matrices))
+        if any(mat.type.ndim != 2 for mat in matrices):
+            raise TypeError("All inputs must have dimension 2")
+        return matrices
+
+
+class BlockDiagonal(BaseBlockDiagonal):
+    __props__ = ("n_inputs",)
+
+    def make_node(self, *matrices):
+        matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
+        dtype = _largest_common_dtype(matrices)
+        out_type = pytensor.tensor.matrix(dtype=dtype)
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        dtype = node.outputs[0].type.dtype
+        output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
+
+
+def block_diag(*matrices: TensorVariable):
+    """
+    Construct a block diagonal matrix from a sequence of input tensors.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
+
+    Returns
+    -------
+    out: tensor
+        The block diagonal matrix formed from the input matrices.
+
+    Examples
+    --------
+    Create a block diagonal matrix from two 2x2 matrices:
+
+    .. code-block:: python
+
+        import numpy as np
+        import pytensor.tensor as pt
+        from pytensor.tensor.linalg import block_diag
+
+        A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
+        B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
+
+        result = block_diag(A, B)
+        print(result.eval())
+        >>> Out: array([[1, 2, 0, 0],
+        >>>        [3, 4, 0, 0],
+        >>>        [0, 0, 5, 6],
+        >>>        [0, 0, 7, 8]])
+    """
+    _block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
+    return _block_diagonal_matrix(*matrices)
+
+
 __all__ = [
     "cholesky",
     "solve",
@@ -918,4 +1019,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     "solve_continuous_lyapunov",
     "solve_discrete_are",
     "solve_triangular",
+    "block_diag",
 ]
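
The `grad` method splits the output gradient back into per-block slices via `ptb.ix_`, so each input receives only the region it occupies. A hedged check of that behavior (assumes this commit is installed; shapes chosen arbitrarily):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.dmatrix("A")
    B = pt.dmatrix("B")
    out = block_diag(A, B)
    # d(sum(out))/dA is all-ones with A's shape: A only sees its own block.
    gA, gB = pt.grad(out.sum(), [A, B])
    print(gA.eval({A: np.ones((2, 2)), B: np.ones((1, 1))}))
    # [[1. 1.]
    #  [1. 1.]]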
