Add linalg.block_diag and sparse equivalent #576

Merged
merged 16 commits on Jan 7, 2024
Changes from all commits
10 changes: 9 additions & 1 deletion pytensor/link/jax/dispatch/slinalg.py
@@ -1,7 +1,7 @@
import jax

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular


@jax_funcify.register(Cholesky)
@@ -45,3 +45,11 @@ def solve_triangular(A, b):
)

return solve_triangular


@jax_funcify.register(BlockDiagonal)
def jax_funcify_BlockDiagonal(op, **kwargs):
def block_diag(*inputs):
return jax.scipy.linalg.block_diag(*inputs)

return block_diag
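
For context, a minimal sketch of how this new dispatch would be exercised (not part of the diff; assumes a working JAX backend and that block_diag is importable as below):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.matrix("A")
B = pt.matrix("B")
# Compiling in JAX mode routes BlockDiagonal through the registration above.
f = pytensor.function([A, B], block_diag(A, B), mode="JAX")
print(f(np.eye(2), np.ones((3, 3))).shape)  # (5, 5)
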
24 changes: 23 additions & 1 deletion pytensor/link/numba/dispatch/slinalg.py
@@ -9,7 +9,7 @@

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor.slinalg import SolveTriangular
from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular


_PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
return res

return solve_triangular


@numba_funcify.register(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
dtype = node.outputs[0].dtype

# TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
@numba_basic.numba_njit(inline="never")
def block_diag(*arrs):
shapes = np.array([a.shape for a in arrs], dtype="int")
out_shape = [int(s) for s in np.sum(shapes, axis=0)]
out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)

r, c = 0, 0
for arr, shape in zip(arrs, shapes):
rr, cc = shape
out[r : r + rr, c : c + cc] = arr
r += rr
c += cc
return out

return block_diag
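
For readers unfamiliar with the algorithm, the njit-ed helper above mirrors this pure-NumPy reference (an illustrative sketch, not part of the diff):

import numpy as np

def block_diag_reference(*arrs):
    # Sum the (rows, cols) of every block to get the output shape.
    shapes = np.array([a.shape for a in arrs])
    out = np.zeros(shapes.sum(axis=0), dtype=np.result_type(*arrs))
    r = c = 0
    for arr, (rr, cc) in zip(arrs, shapes):
        out[r : r + rr, c : c + cc] = arr  # place each block on the diagonal
        r += rr
        c += cc
    return out
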
96 changes: 87 additions & 9 deletions pytensor/sparse/basic.py
@@ -7,6 +7,7 @@
TODO: Automatic methods for determining best sparse format?

"""
from typing import Literal
from warnings import warn

import numpy as np
@@ -47,6 +48,7 @@
trunc,
)
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
from pytensor.tensor.type import TensorType
from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@

sparse_formats = ["csc", "csr"]


"""
Types of sparse matrices to use for testing.

@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):

as_sparse = as_sparse_variable


as_sparse_or_tensor_variable = as_symbolic


@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
return r

def __str__(self):
return f"{self.__class__.__name__ }{{axis={self.axis}}}"
return f"{self.__class__.__name__}{{axis={self.axis}}}"


def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):

greater_equal_s_d = GreaterEqualSD()


eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)


neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)


lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)


gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)


le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)

ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
l = []
if self.inplace:
l.append("inplace")
return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
return f"{self.__class__.__name__}{{{', '.join(l)}}}"

def make_node(self, x):
"""
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
# Simplify code by splitting into DotSS and DotSD.

__props__ = ()

# The grad_preserves_dense attribute doesn't change the
# execution behavior. To let the optimizer merge nodes with
# different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,85 @@ def grad(self, inputs, grads):


construct_sparse_from_list = ConstructSparseFromList()


class SparseBlockDiagonal(BaseBlockDiagonal):
__props__ = (
"n_inputs",
"format",
)

def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
super().__init__(n_inputs)
self.format = format

def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(
matrices, as_sparse_or_tensor_variable
)
dtype = _largest_common_dtype(matrices)
out_type = matrix(format=self.format, dtype=dtype)

return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.sparse.block_diag(
inputs, format=self.format
).astype(dtype)


def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
r"""
Construct a block diagonal matrix from a sequence of input matrices.

Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

[[A, 0, 0],
[0, B, 0],
[0, 0, C]]

Parameters
----------
A, B, C ... : tensors
Input matrices to form the block diagonal matrix. Each input must have exactly two dimensions.

Note that the input matrices need not be sparse themselves, and will be automatically converted to the
requested format if they are not.

format: str, optional
The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.

Returns
-------
out: sparse matrix tensor
Symbolic sparse matrix in the specified format.

Examples
--------
Create a sparse block diagonal matrix from two sparse 2x2 matrices:

.. code-block:: python

import numpy as np
from pytensor.sparse import block_diag
from scipy.sparse import csr_matrix

A = csr_matrix([[1, 2], [3, 4]])
B = csr_matrix([[5, 6], [7, 8]])
result_sparse = block_diag(A, B, format='csr')

print(result_sparse)
# SparseVariable{csr,int32}

print(result_sparse.toarray().eval())
# array([[1, 2, 0, 0],
#        [3, 4, 0, 0],
#        [0, 0, 5, 6],
#        [0, 0, 7, 8]])
"""
    if len(matrices) == 1:
        # A single input needs no stacking; return it unchanged.
        return matrices[0]

_sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
return _sparse_block_diagonal(*matrices)
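
A quick numerical check against SciPy's reference implementation (a sketch; it relies on the automatic conversion of inputs described in the docstring above):

import numpy as np
import scipy.sparse
from pytensor.sparse import block_diag

A = scipy.sparse.csr_matrix([[1.0, 2.0], [3.0, 4.0]])
B = scipy.sparse.csr_matrix([[5.0, 6.0], [7.0, 8.0]])
expected = scipy.sparse.block_diag([A, B], format="csc").toarray()
result = block_diag(A, B, format="csc").toarray().eval()
np.testing.assert_allclose(result, expected)
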
19 changes: 19 additions & 0 deletions pytensor/tensor/basic.py
@@ -4269,6 +4269,25 @@ def take_along_axis(arr, indices, axis=0):
return arr[_make_along_axis_idx(arr.shape, indices, axis)]


def ix_(*args):
"""
PyTensor analog of numpy.ix_.

See numpy.lib.index_tricks.ix_ for reference.
"""
out = []
nd = len(args)
for k, new in enumerate(args):
if new is None:
out.append(slice(None))
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
out.append(new)
return tuple(out)
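
The helper mirrors numpy.ix_, which the gradient of BaseBlockDiagonal below relies on; a small sketch of the cross-indexing it enables (hypothetical usage, not part of the diff):

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.basic import ix_

x = pt.arange(16).reshape((4, 4))
rows, cols = ix_(np.array([2, 3]), np.array([0, 1]))
# rows has shape (2, 1), cols has shape (1, 2); broadcasting selects a 2x2 block.
print(x[rows, cols].eval())  # [[ 8  9] [12 13]]
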


__all__ = [
"take_along_axis",
"expand_dims",
104 changes: 103 additions & 1 deletion pytensor/tensor/slinalg.py
@@ -1,6 +1,7 @@
import logging
import typing
import warnings
from functools import reduce
from typing import TYPE_CHECKING, Literal, Optional, Union

import numpy as np
@@ -23,7 +24,6 @@
if TYPE_CHECKING:
from pytensor.tensor import TensorLike


logger = logging.getLogger(__name__)


@@ -908,6 +908,107 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
)


def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
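
This folds NumPy's promotion rules over all input dtypes, for example (illustrative only):

import numpy as np
from functools import reduce

# float32 cannot represent every int32 exactly, so this pair promotes to float64.
print(reduce(np.promote_types, ["int32", "float32"]))  # float64
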


class BaseBlockDiagonal(Op):
__props__ = ("n_inputs",)

def __init__(self, n_inputs):
input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
self.gufunc_signature = f"{input_sig}->(m,n)"

if n_inputs == 0:
raise ValueError("n_inputs must be greater than 0")
self.n_inputs = n_inputs

def grad(self, inputs, gout):
shapes = pt.stack([i.shape for i in inputs])
index_end = shapes.cumsum(0)
index_begin = index_end - shapes
slices = [
ptb.ix_(
pt.arange(index_begin[i, 0], index_end[i, 0]),
pt.arange(index_begin[i, 1], index_end[i, 1]),
)
for i in range(len(inputs))
]
return [gout[0][slc] for slc in slices]

def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes)
return [(pt.add(*first), pt.add(*second))]

def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
if len(matrices) != self.n_inputs:
raise ValueError(
f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
)
matrices = list(map(as_tensor_func, matrices))
if any(mat.type.ndim != 2 for mat in matrices):
raise TypeError("All inputs must have dimension 2")
return matrices


class BlockDiagonal(BaseBlockDiagonal):
__props__ = ("n_inputs",)

def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
dtype = _largest_common_dtype(matrices)
out_type = pytensor.tensor.matrix(dtype=dtype)
return Apply(self, matrices, [out_type])

def perform(self, node, inputs, output_storage, params=None):
dtype = node.outputs[0].type.dtype
output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)


def block_diag(*matrices: TensorVariable):
"""
Construct a block diagonal matrix from a sequence of input tensors.

Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:

[[A, 0, 0],
[0, B, 0],
[0, 0, C]]

Parameters
----------
A, B, C ... : tensors
Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
inputs should have at least 2 dimensions.

Returns
-------
out: tensor
The block diagonal matrix formed from the input matrices.

Examples
--------
Create a block diagonal matrix from two 2x2 matrices:

.. code-block:: python

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))

result = block_diag(A, B)
print(result.eval())
# array([[1, 2, 0, 0],
#        [3, 4, 0, 0],
#        [0, 0, 5, 6],
#        [0, 0, 7, 8]])
"""
_block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
return _block_diagonal_matrix(*matrices)
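
Because the core Op is wrapped in Blockwise with the gufunc signature built in BaseBlockDiagonal, leading batch dimensions should broadcast; a sketch of the expected behavior (an assumption based on the signature, not demonstrated in the diff):

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag

A = pt.tensor3("A")  # shape (batch, 2, 2)
B = pt.tensor3("B")  # shape (batch, 3, 3)
out = block_diag(A, B)
print(out.eval({A: np.ones((5, 2, 2)), B: np.zeros((5, 3, 3))}).shape)  # (5, 5, 5)
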


__all__ = [
"cholesky",
"solve",
@@ -918,4 +1019,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
"solve_continuous_lyapunov",
"solve_discrete_are",
"solve_triangular",
"block_diag",
]