
Add linalg.block_diag and sparse equivalent #576

Merged (16 commits) on Jan 7, 2024
Changes from 9 commits
15 changes: 14 additions & 1 deletion pytensor/link/jax/dispatch/slinalg.py
@@ -1,7 +1,12 @@
import jax

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
from pytensor.tensor.slinalg import (
BlockDiagonalMatrix,
Cholesky,
Solve,
SolveTriangular,
)


@jax_funcify.register(Cholesky)
@@ -45,3 +50,11 @@ def solve_triangular(A, b):
)

return solve_triangular


@jax_funcify.register(BlockDiagonalMatrix)
def jax_funcify_BlockDiagonalMatrix(op, **kwargs):
def block_diag(*inputs):
return jax.scipy.linalg.block_diag(*inputs)

return block_diag
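Not part of the diff, but a minimal sketch of how this dispatch would be exercised, assuming a working JAX install and PyTensor's JAX compilation mode: compiling a graph that contains the new Op with mode="JAX" should route it through the function registered above.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.matrix("A")
    B = pt.matrix("B")
    out = block_diag(A, B)  # builds a BlockDiagonalMatrix node

    # mode="JAX" dispatches the Op through jax_funcify_BlockDiagonalMatrix above
    fn = pytensor.function([A, B], out, mode="JAX")
    print(fn(np.eye(2, dtype=pytensor.config.floatX), np.ones((3, 3), dtype=pytensor.config.floatX)))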
24 changes: 23 additions & 1 deletion pytensor/link/numba/dispatch/slinalg.py
@@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonalMatrix, SolveTriangular


_PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
return res

return solve_triangular


@numba_funcify.register(BlockDiagonalMatrix)
def numba_funcify_BlockDiagonalMatrix(op, node, **kwargs):
dtype = node.outputs[0].dtype

# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit(inline="never")
def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)]
out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)

r, c = 0, 0
for arr, shape in zip(arrs, shapes):
rr, cc = shape
out[r : r + rr, c : c + cc] = arr
r += rr
c += cc
return out

return block_diag
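As a reference for what the jitted loop above computes, here is a plain NumPy sketch of the same fill logic, checked against scipy.linalg.block_diag (illustration only, not part of the diff):

    import numpy as np
    import scipy.linalg

    def block_diag_ref(*arrs):
        # Sum per-input (rows, cols) to get the output shape, then paste each
        # block along the diagonal while advancing the row/column offsets.
        shapes = np.array([a.shape for a in arrs], dtype="int")
        out = np.zeros(tuple(shapes.sum(axis=0)))
        r = c = 0
        for arr, (rr, cc) in zip(arrs, shapes):
            out[r : r + rr, c : c + cc] = arr
            r += rr
            c += cc
        return out

    A, B = np.ones((2, 2)), np.arange(9.0).reshape(3, 3)
    np.testing.assert_allclose(block_diag_ref(A, B), scipy.linalg.block_diag(A, B))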
96 changes: 87 additions & 9 deletions pytensor/sparse/basic.py
@@ -7,6 +7,7 @@
TODO: Automatic methods for determining best sparse format?

"""
from typing import Literal
from warnings import warn

import numpy as np
@@ -47,6 +48,7 @@
trunc,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.slinalg import BaseBlockDiagonal, largest_common_dtype
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@

sparse_formats = ["csc", "csr"]


"""
Types of sparse matrices to use for testing.

@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):

as_sparse = as_sparse_variable


as_sparse_or_tensor_variable = as_symbolic


@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
return r

def __str__(self):
return f"{self.__class__.__name__ }{{axis={self.axis}}}"
return f"{self.__class__.__name__}{{axis={self.axis}}}"


def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):

greater_equal_s_d = GreaterEqualSD()


eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)


neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)


lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)


gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)


le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)

ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
l = []
if self.inplace:
l.append("inplace")
return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
return f"{self.__class__.__name__}{{{', '.join(l)}}}"

def make_node(self, x):
"""
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
# Simplify code by splitting into DotSS and DotSD.

__props__ = ()

# The grad_preserves_dense attribute doesn't change the
# execution behavior. To let the optimizer merge nodes with
# different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,85 @@ def grad(self, inputs, grads):


construct_sparse_from_list = ConstructSparseFromList()


class SparseBlockDiagonalMatrix(BaseBlockDiagonal):
def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None):
if not matrices:
raise ValueError("no matrices to allocate")
dtype = largest_common_dtype(matrices)
matrices = list(map(as_sparse_or_tensor_variable, matrices))

if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("all data arguments must be matrices")

out_type = matrix(format=format, dtype=dtype, name=name)
return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
format = node.outputs[0].type.format
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype(
dtype
)


_sparse_block_diagonal = SparseBlockDiagonalMatrix()


def block_diag(
*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc", name=None
):
r"""
Construct a block diagonal matrix from a sequence of input matrices.

Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

[[A, 0, 0],
[0, B, 0],
[0, 0, C]]

Parameters
----------
A, B, C ... : tensors or array-like
Inputs to form the block diagonal matrix. Each input should have the same number of dimensions,
and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.

Note that the input matrices need not be sparse themselves, and will be automatically converted to the
requested format if they are not.
format: str, optional
The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
name: str, optional
Name of the output tensor.

Returns
-------
out: sparse matrix tensor
Symbolic sparse matrix in the specified format.

Examples
--------
Create a sparse block diagonal matrix from two sparse 2x2 matrices:

.. code-block:: python

import numpy as np
from pytensor.sparse import block_diag
from scipy.sparse import csr_matrix

A = csr_matrix([[1, 2], [3, 4]])
B = csr_matrix([[5, 6], [7, 8]])
result_sparse = block_diag(A, B, format='csr', name='X')

print(result_sparse)
>>> SparseVariable{csr,int32}

print(result_sparse.toarray().eval())
>>> array([[1, 2, 0, 0],
>>> [3, 4, 0, 0],
>>> [0, 0, 5, 6],
>>> [0, 0, 7, 8]])
"""
if len(matrices) == 1:
return matrices[0]

return _sparse_block_diagonal(*matrices, format=format, name=name)
19 changes: 19 additions & 0 deletions pytensor/tensor/basic.py
@@ -4269,6 +4269,25 @@ def take_along_axis(arr, indices, axis=0):
return arr[_make_along_axis_idx(arr.shape, indices, axis)]


def ix_(*args):
"""
PyTensor np.ix_ analog

See numpy.lib.index_tricks.ix_ for reference
"""
out = []
nd = len(args)
for k, new in enumerate(args):
if new is None:
out.append(slice(None))
continue
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
out.append(new)
return tuple(out)

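A minimal usage sketch (not part of the diff; assumes ix_ remains importable from pytensor.tensor.basic): as with numpy.ix_, the reshaped index vectors broadcast against each other, so advanced indexing selects the cross product of the given rows and columns.

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.basic import ix_

    x = pt.dmatrix("x")
    rows = pt.as_tensor(np.array([0, 2]))
    cols = pt.as_tensor(np.array([1, 3]))

    # rows is reshaped to (2, 1) and cols to (1, 2), so the result is the 2x2 sub-block
    sub = x[ix_(rows, cols)]
    print(sub.eval({x: np.arange(16.0).reshape(4, 4)}))
    # [[ 1.  3.]
    #  [ 9. 11.]]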

__all__ = [
"take_along_axis",
"expand_dims",
98 changes: 97 additions & 1 deletion pytensor/tensor/slinalg.py
@@ -1,3 +1,4 @@
import functools as ft
import logging
import typing
import warnings
@@ -23,7 +24,6 @@
if TYPE_CHECKING:
from pytensor.tensor import TensorLike


logger = logging.getLogger(__name__)


@@ -908,6 +908,101 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
)


def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])

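For intuition (illustration only, not part of the diff): the reduction applies NumPy's pairwise promotion rules across all input dtypes, so mixed integer and float inputs promote to a common floating-point type.

    import functools as ft
    import numpy as np

    dtypes = ["int32", "float32", "int64"]
    # promote_types("int32", "float32") -> float64, then float64 vs. int64 -> float64
    print(ft.reduce(np.promote_types, dtypes))  # float64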

def block_diag_grad(inputs, gout):
shapes = pt.stack([i.shape for i in inputs])
index_end = shapes.cumsum(0)
index_begin = index_end - shapes
slices = [
ptb.ix_(
pt.arange(index_begin[i, 0], index_end[i, 0]),
pt.arange(index_begin[i, 1], index_end[i, 1]),
)
for i in range(len(inputs))
]
return [gout[0][slc] for slc in slices]

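To make the slicing above concrete, here is a NumPy illustration of the same idea (not part of the diff): the gradient with respect to each input is the matching diagonal block carved out of the output gradient, located by the cumulative sums of the input shapes.

    import numpy as np

    g_out = np.arange(25.0).reshape(5, 5)  # upstream gradient of a (2+3) x (2+3) output
    shapes = np.array([(2, 2), (3, 3)])
    ends = shapes.cumsum(0)
    starts = ends - shapes

    g_A = g_out[np.ix_(np.arange(starts[0, 0], ends[0, 0]), np.arange(starts[0, 1], ends[0, 1]))]
    g_B = g_out[np.ix_(np.arange(starts[1, 0], ends[1, 0]), np.arange(starts[1, 1], ends[1, 1]))]
    assert g_A.shape == (2, 2) and g_B.shape == (3, 3)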

class BaseBlockDiagonal(Op):
def grad(self, inputs, gout):
return block_diag_grad(inputs, gout)

def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes)
return [(pt.add(*first), pt.add(*second))]


class BlockDiagonalMatrix(BaseBlockDiagonal):
def make_node(self, *matrices, name=None):
if not matrices:
raise ValueError("no matrices to allocate")
dtype = largest_common_dtype(matrices)
matrices = list(map(pt.as_tensor, matrices))

if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("all data arguments must be matrices")

out_type = pytensor.tensor.matrix(dtype=dtype, name=name)
return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)


_block_diagonal_matrix = BlockDiagonalMatrix()


def block_diag(*matrices: TensorVariable, name=None):
"""
Construct a block diagonal matrix from a sequence of input tensors.

Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

[[A, 0, 0],
[0, B, 0],
[0, 0, C]]

Parameters
----------
A, B, C ... : tensors
Input matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions, and the
block diagonal matrix will be formed using the right-most two dimensions of each input matrix.

[Review thread anchored on this line]

ricardoV94 (Member): This is not correct. Blockwise accepts a different number of batch dims and also broadcasts when they have length 1.

jessegrabowski (Member Author), Jan 6, 2024: I got errors when I tried tensors with different batch dims, but I didn't try broadcasting to dimensions with size 1.

ricardoV94 (Member): Do you still see errors? Blockwise should introduce expand dims, so the only failure case would be broadcasting?

jessegrabowski (Member Author): Broadcasting works, I added a test for it. It was failing when I tried different batch sizes, which doesn't make sense anyway, I think.

ricardoV94 (Member): What do you mean by different batch sizes? Blockwise adds expand dims automatically to align the number of batch dims, so that shouldn't be possible?

jessegrabowski (Member Author), Jan 7, 2024: This errors:

    # Different batch sizes
    A = np.random.normal(size=(batch_size + 3, 2, 2)).astype(config.floatX)
    B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
    result = block_diag(A, B).eval()

with:

    ValueError: Incompatible Blockwise batch input shapes [(8, 2, 2), (5, 4, 4)]

But I think it's supposed to. What does it even mean to batch those two together?

ricardoV94 (Member), Jan 7, 2024: Yeah, that's invalid, batch shapes must be broadcastable. 8 and 5 are not broadcastable. I thought you were saying inputs with different number of dimensions were failing.

[End review thread]
name: str, optional
Name of the block diagonal matrix.

Returns
-------
out: tensor
The block diagonal matrix formed from the input matrices.

Examples
--------
Create a block diagonal matrix from two 2x2 matrices:

.. code-block:: python

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))

result = block_diag(A, B, name='X')
print(result.eval())
>>> Out: array([[1, 2, 0, 0],
>>> [3, 4, 0, 0],
>>> [0, 0, 5, 6],
>>> [0, 0, 7, 8]])
"""
if len(matrices) == 1: # graph optimization
return matrices[0]
return _block_diagonal_matrix(*matrices, name=name)


__all__ = [
"cholesky",
"solve",
@@ -918,4 +1013,5 @@
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diag",
]
20 changes: 20 additions & 0 deletions tests/link/jax/test_slinalg.py
@@ -129,3 +129,23 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
np.arange(10).astype(config.floatX),
],
)


def test_jax_block_diag():
A = matrix("A")
B = matrix("B")
C = matrix("C")
D = matrix("D")

out = pt_slinalg.block_diag(A, B, C, D)

out_fg = FunctionGraph([A, B, C, D], [out])
compare_jax_and_py(
out_fg,
[
np.random.normal(size=(5, 5)).astype(config.floatX),
np.random.normal(size=(3, 3)).astype(config.floatX),
np.random.normal(size=(2, 2)).astype(config.floatX),
np.random.normal(size=(4, 4)).astype(config.floatX),
],
)