Add linalg.block_diag and sparse equivalent #576


Merged · 16 commits · Jan 7, 2024
19 changes: 19 additions & 0 deletions pytensor/tensor/basic.py
@@ -4269,6 +4269,25 @@ def take_along_axis(arr, indices, axis=0):
return arr[_make_along_axis_idx(arr.shape, indices, axis)]


def ix_(*args):
"""
PyTensor np.ix_ analog

See numpy.lib.index_tricks.ix_ for reference
"""
    out = []
    nd = len(args)
    for k, new in enumerate(args):
        if new is None:
            # None selects the whole axis, like a bare slice in numpy.ix_
            out.append(slice(None))
            continue
        new = as_tensor(new)
        if new.ndim != 1:
            raise ValueError("Cross index must be 1 dimensional")
        # Reshape so each index broadcasts along its own axis of the open mesh
        new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
        out.append(new)
    return tuple(out)
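
A quick illustration of the open-mesh behavior (an editor sketch, not part of the diff; it assumes `ix_` is importable from `pytensor.tensor.basic` as added above):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.basic import ix_

    x = pt.arange(12).reshape((3, 4))
    rows = pt.as_tensor(np.array([0, 2]))
    cols = pt.as_tensor(np.array([1, 3]))

    # ix_ reshapes each 1-d index so they broadcast against each other,
    # selecting the cross product of rows {0, 2} and columns {1, 3}.
    print(x[ix_(rows, cols)].eval())
    # [[ 1  3]
    #  [ 9 11]]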


__all__ = [
"take_along_axis",
"expand_dims",
112 changes: 111 additions & 1 deletion pytensor/tensor/slinalg.py
@@ -1,3 +1,4 @@
import functools as ft
import logging
import typing
import warnings
@@ -23,7 +24,6 @@
if TYPE_CHECKING:
from pytensor.tensor import TensorLike


logger = logging.getLogger(__name__)


@@ -908,6 +908,115 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
)


def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
    # Fold numpy's promotion rules pairwise until one dtype fits all inputs
    return ft.reduce(np.promote_types, [x.dtype for x in tensors])
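
For intuition (editor note, using numpy's own promotion rules, which this helper delegates to):

    import numpy as np

    # Mixing int32 and float32 blocks promotes the result to float64
    np.promote_types(np.int32, np.float32)  # dtype('float64')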


class BlockDiagonalMatrix(Op):
__props__ = ("sparse", "format")

def __init__(self, sparse=False, format="csr"):
if format not in ("csr", "csc"):
raise ValueError(f"format must be one of: 'csr', 'csc', got {format}")
self.sparse = sparse
self.format = format

    def make_node(self, *matrices):
        if not matrices:
            raise ValueError("no matrices to allocate")
        # Convert first so plain (nested-list) inputs also expose a dtype
        matrices = list(map(pt.as_tensor, matrices))
        dtype = largest_common_dtype(matrices)

if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("all data arguments must be matrices")
if self.sparse:
out_type = pytensor.sparse.matrix(self.format, dtype=dtype)
else:
out_type = pytensor.tensor.matrix(dtype=dtype)
return Apply(self, matrices, [out_type])

    def perform(self, node, inputs, output_storage, params=None):
        dtype = largest_common_dtype(inputs)
        if self.sparse:
            output_storage[0][0] = scipy.sparse.block_diag(inputs, format=self.format, dtype=dtype)
        else:
            output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)

    def grad(self, inputs, gout):
        # Each input's gradient is its own diagonal block of the output
        # gradient: compute every block's (row, col) offsets, then slice
        # the blocks out with an open-mesh (ix_) index.
        shapes = pt.stack([i.shape for i in inputs])
        index_end = shapes.cumsum(0)
        index_begin = index_end - shapes
        slices = [
            ptb.ix_(
                pt.arange(index_begin[i, 0], index_end[i, 0]),
                pt.arange(index_begin[i, 1], index_end[i, 1]),
            )
            for i in range(len(inputs))
        ]
        return [gout[0][slc] for slc in slices]

    def infer_shape(self, fgraph, nodes, shapes):
        # The output shape is (sum of input row counts, sum of input column counts)
        first, second = zip(*shapes)
        return [(pt.add(*first), pt.add(*second))]
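
A minimal sketch of how this gradient rule can be exercised (an editor illustration assuming the dense path; the variable names are illustrative, not from the diff):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diagonal

    A = pt.matrix("A")
    B = pt.matrix("B")
    out = block_diagonal([A, B])

    # The gradient of the sum is all ones; each input recovers its own block
    gA, gB = pytensor.grad(out.sum(), [A, B])
    f = pytensor.function([A, B], [gA, gB])
    ga, gb = f(np.ones((2, 2), dtype=A.dtype), np.ones((3, 3), dtype=B.dtype))
    print(ga.shape, gb.shape)  # (2, 2) (3, 3)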


def block_diagonal(
matrices: typing.Sequence[TensorVariable],
sparse: bool = False,
format: Literal["csr", "csc"] = "csr",
):
"""
Construct a block diagonal matrix from a sequence of input matrices.

Parameters
----------
matrices: sequence of tensors
        Input matrices to form the block diagonal matrix. Each matrix should have the same
        number of dimensions, and the block diagonal matrix will be formed along the first
        axis of the matrices.

    [Review thread on this parameter description]

    ricardoV94 (Member): This is not correct. Blockwise accepts different numbers of batch
    dims and also broadcasts when they have length 1.

    jessegrabowski (Member, Author), Jan 6, 2024: I got errors when I tried tensors with
    different batch dims, but I didn't try broadcasting to dimensions with size 1.

    ricardoV94 (Member): Do you still see errors? Blockwise should introduce expand dims,
    so the only failure case would be broadcasting?

    jessegrabowski (Member, Author): Broadcasting works; I added a test for it. It was
    failing when I tried different batch sizes, which I don't think makes sense anyway.

    ricardoV94 (Member): What do you mean by different batch sizes? Blockwise adds expand
    dims automatically to align the number of batch dims, so that shouldn't be possible?

    jessegrabowski (Member, Author), Jan 7, 2024: This errors:

        # Different batch sizes
        A = np.random.normal(size=(batch_size + 3, 2, 2)).astype(config.floatX)
        B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
        result = block_diag(A, B).eval()

    with:

        ValueError: Incompatible Blockwise batch input shapes [(8, 2, 2), (5, 4, 4)]

    But I think it's supposed to. What does it even mean to batch those two together?

    ricardoV94 (Member), Jan 7, 2024: Yeah, that's invalid: batch shapes must be
    broadcastable, and 8 and 5 are not broadcastable. I thought you were saying inputs
    with a different number of dimensions were failing.

    [End of review thread — a short numpy sketch of the broadcastability rule follows
    the function below.]
    sparse : bool, optional
        If True, the function returns a sparse matrix in the specified format. Default is False.
    format : str, optional
        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.

Returns
-------
    out: TensorVariable or sparse matrix variable
        The block diagonal matrix formed from the input matrices. If `sparse` is True, the output will be a symbolic
        sparse matrix in the specified format. Otherwise, a symbolic dense tensor will be returned.

Examples
--------
Create a block diagonal matrix from two 2x2 matrices:

    .. code-block:: python

        import numpy as np
        import pytensor.tensor as pt
        from pytensor.tensor.slinalg import block_diagonal

        matrices = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]])]
        matrices = [pt.as_tensor_variable(mat) for mat in matrices]
        result = block_diagonal(matrices)

        print(result.eval())
        # Out: array([[1, 2, 0, 0],
        #             [3, 4, 0, 0],
        #             [0, 0, 5, 6],
        #             [0, 0, 7, 8]])

Create a sparse block diagonal matrix from two sparse 2x2 matrices:

    .. code-block:: python

        from scipy.sparse import csr_matrix

        matrices_sparse = [csr_matrix([[1, 2], [3, 4]]), csr_matrix([[5, 6], [7, 8]])]
        result_sparse = block_diagonal(matrices_sparse, sparse=True)

    The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
"""
    if len(matrices) == 1:
        # A single input is already its own block diagonal; skip the Op entirely
        return matrices[0]
    return BlockDiagonalMatrix(sparse=sparse, format=format)(*matrices)
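
A footnote to the review thread above: the batch-shape rule under discussion is ordinary numpy broadcasting. A numpy-only sketch (editor illustration, not part of the diff):

    import numpy as np

    np.broadcast_shapes((1,), (5,))      # -> (5,): a size-1 batch dim broadcasts
    try:
        np.broadcast_shapes((8,), (5,))  # 8 vs 5 cannot broadcast
    except ValueError as err:
        print(err)  # mirrors the "Incompatible Blockwise batch input shapes" error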


__all__ = [
"cholesky",
"solve",
@@ -918,4 +1027,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diagonal",
]
12 changes: 12 additions & 0 deletions tests/tensor/test_slinalg.py
@@ -15,6 +15,7 @@
Solve,
SolveBase,
SolveTriangular,
block_diagonal,
cho_solve,
cholesky,
eigvalsh,
@@ -661,3 +662,14 @@ def test_solve_discrete_are_grad():
rng=rng,
abs_tol=atol,
)


def test_block_diagonal():
    matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
    result = block_diagonal(matrices)
    np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(*matrices))

    result = block_diagonal(matrices, format="csr", sparse=True)
    sp_result = scipy.sparse.block_diag(matrices, format="csr")
    # Compile/evaluate once and reuse the result for both assertions
    evaled = result.eval()
    assert type(evaled) == type(sp_result)
    np.testing.assert_allclose(evaled.toarray(), sp_result.toarray())
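
A possible companion test (editor sketch, not part of the PR) that checks the Op's gradient numerically; it assumes the `utt.verify_grad` helper and the `config` import already used elsewhere in this test module:

    def test_block_diagonal_grad():
        # Hypothetical addition: numerically verify the block-slicing grad rule
        rng = np.random.default_rng(utt.fetch_seed())
        A = rng.normal(size=(2, 2)).astype(config.floatX)
        B = rng.normal(size=(3, 3)).astype(config.floatX)
        utt.verify_grad(lambda a, b: block_diagonal([a, b]), [A, B], rng=rng)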