
Commit fd26b74

Split block_diag into sparse and dense version
Closely follow scipy function signature for `block_diag`
1 parent d809f1c commit fd26b74
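
In short, the old `block_diagonal(matrices, sparse=..., format=...)` Op is split into a dense `block_diag` in `pytensor.tensor.slinalg` and a sparse `block_diag` in `pytensor.sparse`, both taking the matrices as positional arguments as SciPy does. A minimal usage sketch based on the docstrings and tests in this diff (the 2x2 values are purely illustrative, and the top-level import paths are assumed from those docstrings):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag                 # dense version
    from pytensor.sparse import block_diag as sparse_block_diag    # sparse version

    A = pt.as_tensor_variable(np.array([[1.0, 2.0], [3.0, 4.0]]))
    B = pt.as_tensor_variable(np.array([[5.0, 6.0], [7.0, 8.0]]))

    dense = block_diag(A, B, name="dense_bd")                        # symbolic dense matrix
    sparse = sparse_block_diag(A, B, format="csr", name="sparse_bd") # symbolic CSR matrix

    print(dense.eval())             # 4x4 NumPy array with A and B on the diagonal
    print(sparse.eval().toarray())  # same values, returned as a scipy.sparse CSR matrix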

File tree

4 files changed: +141, -78 lines changed


pytensor/sparse/basic.py

Lines changed: 78 additions & 9 deletions
@@ -7,6 +7,7 @@
 TODO: Automatic methods for determining best sparse format?
 
 """
+from typing import Literal
 from warnings import warn
 
 import numpy as np
@@ -47,6 +48,7 @@
     trunc,
 )
 from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.slinalg import BaseBlockDiagonal, largest_common_dtype
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@
 
 sparse_formats = ["csc", "csr"]
 
-
 """
 Types of sparse matrices to use for testing.
 
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
 
 as_sparse = as_sparse_variable
 
-
 as_sparse_or_tensor_variable = as_symbolic
 
 
@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
         return r
 
     def __str__(self):
-        return f"{self.__class__.__name__ }{{axis={self.axis}}}"
+        return f"{self.__class__.__name__}{{axis={self.axis}}}"
 
 
 def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):
 
 greater_equal_s_d = GreaterEqualSD()
 
-
 eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
 
-
 neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
 
-
 lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
 
-
 gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
 
-
 le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
 
 ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
         l = []
         if self.inplace:
             l.append("inplace")
-        return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
+        return f"{self.__class__.__name__}{{{', '.join(l)}}}"
 
     def make_node(self, x):
         """
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
     # Simplify code by splitting into DotSS and DotSD.
 
     __props__ = ()
+
     # The grad_preserves_dense attribute doesn't change the
     # execution behavior. To let the optimizer merge nodes with
     # different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,76 @@ def grad(self, inputs, grads):
 
 
 construct_sparse_from_list = ConstructSparseFromList()
+
+
+class SparseBlockDiagonalMatrix(BaseBlockDiagonal):
+    def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None):
+        if not matrices:
+            raise ValueError("no matrices to allocate")
+        dtype = largest_common_dtype(matrices)
+        matrices = list(map(pytensor.tensor.as_tensor, matrices))
+
+        if any(mat.type.ndim != 2 for mat in matrices):
+            raise TypeError("all data arguments must be matrices")
+
+        out_type = matrix(format=format, dtype=dtype, name=name)
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        format = node.outputs[0].type.format
+        dtype = largest_common_dtype(inputs)
+        output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype(
+            dtype
+        )
+
+
+_sparse_block_diagonal = SparseBlockDiagonalMatrix()
+
+
+def block_diag(
+    *matrices: TensorVariable, format: Literal["csc", "csr"] = "csc", name=None
+):
+    r"""
+    Construct a block diagonal matrix from a sequence of input matrices.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input sparse matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions,
+        and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
+    format: str, optional
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
+    name: str, optional
+        Name of the output tensor.
+
+    Returns
+    -------
+    out: sparse matrix tensor
+        Symbolic sparse matrix in the specified format.
+
+    Examples
+    --------
+    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
+
+    .. code-block:: python
+        import numpy as np
+        from pytensor.sparse import block_diag
+        from scipy.sparse import csr_matrix
+
+        A = csr_matrix([[1, 2], [3, 4]])
+        B = csr_matrix([[5, 6], [7, 8]])
+        result_sparse = block_diag(A, B, format='csr', name='X')
+        print(result_sparse.eval())
+
+    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
+    """
+    if len(matrices) == 1:
+        return matrices
+
+    return _sparse_block_diagonal(*matrices, format=format, name=name)
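
For reference, the new sparse Op's `perform` defers entirely to `scipy.sparse.block_diag`. The following standalone SciPy/NumPy sketch (no PyTensor involved, values purely illustrative) shows the numerical result the Op is expected to produce:

    import numpy as np
    import scipy.sparse

    A = np.array([[1.0, 2.0], [3.0, 4.0]])
    B = np.array([[5.0, 6.0], [7.0, 8.0]])

    # Same call the Op makes in perform(): stack the blocks on the diagonal
    # and return them in the requested sparse format ("csc" by default).
    out = scipy.sparse.block_diag([A, B], format="csc")

    print(out.toarray())
    # [[1. 2. 0. 0.]
    #  [3. 4. 0. 0.]
    #  [0. 0. 5. 6.]
    #  [0. 0. 7. 8.]]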

pytensor/tensor/slinalg.py

Lines changed: 48 additions & 62 deletions
@@ -912,77 +912,72 @@ def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
     return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
 
 
-class BlockDiagonalMatrix(Op):
-    __props__ = ("sparse", "format")
+def block_diag_grad(inputs, gout):
+    shapes = pt.stack([i.shape for i in inputs])
+    index_end = shapes.cumsum(0)
+    index_begin = index_end - shapes
+    slices = [
+        ptb.ix_(
+            pt.arange(index_begin[i, 0], index_end[i, 0]),
+            pt.arange(index_begin[i, 1], index_end[i, 1]),
+        )
+        for i in range(len(inputs))
+    ]
+    return [gout[0][slc] for slc in slices]
+
+
+class BaseBlockDiagonal(Op):
+    def grad(self, inputs, gout):
+        return block_diag_grad(inputs, gout)
+
+    def infer_shape(self, fgraph, nodes, shapes):
+        first, second = zip(*shapes)
+        return [(pt.add(*first), pt.add(*second))]
 
-    def __init__(self, sparse=False, format="csr"):
-        if format not in ("csr", "csc"):
-            raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
-        self.sparse = sparse
-        self.format = format
 
-    def make_node(self, *matrices):
+class BlockDiagonalMatrix(BaseBlockDiagonal):
+    def make_node(self, *matrices, name=None):
         if not matrices:
             raise ValueError("no matrices to allocate")
         dtype = largest_common_dtype(matrices)
         matrices = list(map(pt.as_tensor, matrices))
 
         if any(mat.type.ndim != 2 for mat in matrices):
             raise TypeError("all data arguments must be matrices")
-        if self.sparse:
-            out_type = pytensor.sparse.matrix(self.format, dtype=dtype)
-        else:
-            out_type = pytensor.tensor.matrix(dtype=dtype)
+
+        out_type = pytensor.tensor.matrix(dtype=dtype, name=name)
         return Apply(self, matrices, [out_type])
 
     def perform(self, node, inputs, output_storage, params=None):
         dtype = largest_common_dtype(inputs)
-        if self.sparse:
-            output_storage[0][0] = scipy.sparse.block_diag(inputs, self.format, dtype)
-        else:
-            output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
+        output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
 
-    def grad(self, inputs, gout):
-        shapes = pt.stack([i.shape for i in inputs])
-        index_end = shapes.cumsum(0)
-        index_begin = index_end - shapes
-        slices = [
-            ptb.ix_(
-                pt.arange(index_begin[i, 0], index_end[i, 0]),
-                pt.arange(index_begin[i, 1], index_end[i, 1]),
-            )
-            for i in range(len(inputs))
-        ]
-        return [gout[0][slc] for slc in slices]
 
-    def infer_shape(self, fgraph, nodes, shapes):
-        first, second = zip(*shapes)
-        return [(pt.add(*first), pt.add(*second))]
+_block_diagonal_matrix = BlockDiagonalMatrix()
 
 
-def block_diagonal(
-    matrices: typing.Sequence[TensorVariable],
-    sparse: bool = False,
-    format: Literal["csr", "csc"] = "csr",
-):
+def block_diag(*matrices: TensorVariable, name=None):
     """
-    Construct a block diagonal matrix from a sequence of input matrices.
+    Construct a block diagonal matrix from a sequence of input tensors.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
 
     Parameters
     ----------
-    matrices: sequence of tensors
+    A, B, C ... : tensors
         Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the
-        block diagonal matrix will be formed along the first axis of the matrices.
-    sparse : bool, optional
-        If True, the function returns a sparse matrix in the specified format. Default is True.
-    format: str, optional
-        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.
+        block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
+    name: str, optional
+        Name of the block diagonal matrix.
 
     Returns
     -------
-    out: tensor or sparse matrix tensor
-        The block diagonal matrix formed from the input matrices. If `sparse` is True, the output will be a symbolic
-        sparse matrix in the specified format. Otherwise, a symbolic tensor will be returned.
+    out: tensor
+        The block diagonal matrix formed from the input matrices.
 
     Examples
     --------
@@ -991,30 +986,21 @@ def block_diagonal(
     ..code-block:: python
 
         import numpy as np
-        from pytensor.tensor.slinalg import block_diagonal
+        from pytensor.tensor.slinalg import block_diag
 
-        matrices = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])]
-        matrices = [pt.as_tensor_variable(mat) for mat in matrices]
-        result = block_diagonal(matrices)
+        A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
+        B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))
 
+        result = block_diag(A, B, name='X')
         print(result.eval())
         >>> Out: array([[1, 2, 0, 0],
         >>> [3, 4, 0, 0],
         >>> [0, 0, 5, 6],
         >>> [0, 0, 7, 8]])
-
-    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
-
-    ..code-block:: python
-
-        matrices_sparse = [csr_matrix([[1, 2], [3, 4]]), csr_matrix([[5, 6], [7, 8]])]
-        result_sparse = block_diagonal(matrices_sparse, sparse=True)
-
-    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
     """
     if len(matrices) == 1:  # graph optimization
-        return matrices[0]
-    return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices)
+        return matrices
+    return _block_diagonal_matrix(*matrices, name=name)
 
 
 __all__ = [
@@ -1027,5 +1013,5 @@ def block_diagonal(
     "solve_continuous_lyapunov",
     "solve_discrete_are",
     "solve_triangular",
-    "block_diagonal",
+    "block_diag",
 ]
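
The gradient shared through `BaseBlockDiagonal.grad` cumulatively sums the input shapes and slices the output gradient with `ix_`, so each input receives exactly its own block back. A plain-NumPy sketch of the same indexing (illustrative shapes only, not the symbolic PyTensor version):

    import numpy as np

    # Pretend upstream gradient for block_diag(A, B), with A 2x2 and B 3x3.
    g_out = np.arange(25.0).reshape(5, 5)
    shapes = np.array([[2, 2], [3, 3]])

    index_end = shapes.cumsum(0)       # row/col where each block ends
    index_begin = index_end - shapes   # row/col where each block starts

    grads = [
        g_out[np.ix_(np.arange(b[0], e[0]), np.arange(b[1], e[1]))]
        for b, e in zip(index_begin, index_end)
    ]

    print(grads[0].shape, grads[1].shape)  # (2, 2) and (3, 3): one block per input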

tests/sparse/test_basic.py

Lines changed: 13 additions & 0 deletions
@@ -51,6 +51,7 @@
     add_s_s_data,
     as_sparse_or_tensor_variable,
     as_sparse_variable,
+    block_diag,
     cast,
     clean,
     construct_sparse_from_list,
@@ -3389,3 +3390,15 @@ def _helper(x, y):
 )
 class TestSharedOptions:
     pass
+
+
+@pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
+def test_block_diagonal(format):
+    from scipy.sparse import block_diag as scipy_block_diag
+
+    matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
+    result = block_diag(*matrices, format=format, name="X")
+    sp_result = scipy_block_diag(matrices, format=format)
+
+    assert isinstance(result.eval(), type(sp_result))
+    np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())

tests/tensor/test_slinalg.py

Lines changed: 2 additions & 7 deletions
@@ -15,7 +15,7 @@
     Solve,
     SolveBase,
     SolveTriangular,
-    block_diagonal,
+    block_diag,
     cho_solve,
     cholesky,
     eigvalsh,
@@ -666,10 +666,5 @@ def test_solve_discrete_are_grad():
 
 def test_block_diagonal():
     matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
-    result = block_diagonal(matrices)
+    result = block_diag(*matrices)
     np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(*matrices))
-
-    result = block_diagonal(matrices, format="csr", sparse=True)
-    sp_result = scipy.sparse.block_diag(matrices, format="csr")
-    assert isinstance(result.eval(), type(sp_result))
-    np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())
