diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py
index 4481e442f9..73ddadc2a0 100644
--- a/pytensor/link/jax/dispatch/slinalg.py
+++ b/pytensor/link/jax/dispatch/slinalg.py
@@ -1,7 +1,7 @@
 import jax
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
 
 
 @jax_funcify.register(Cholesky)
@@ -45,3 +45,11 @@ def solve_triangular(A, b):
         )
 
     return solve_triangular
+
+
+@jax_funcify.register(BlockDiagonal)
+def jax_funcify_BlockDiagonal(op, **kwargs):
+    def block_diag(*inputs):
+        return jax.scipy.linalg.block_diag(*inputs)
+
+    return block_diag
diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py
index ad8065defd..a5ac0c6348 100644
--- a/pytensor/link/numba/dispatch/slinalg.py
+++ b/pytensor/link/numba/dispatch/slinalg.py
@@ -9,7 +9,7 @@
 
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import numba_funcify
-from pytensor.tensor.slinalg import SolveTriangular
+from pytensor.tensor.slinalg import BlockDiagonal, SolveTriangular
 
 
 _PTR = ctypes.POINTER
@@ -273,3 +273,25 @@ def solve_triangular(a, b):
         return res
 
     return solve_triangular
+
+
+@numba_funcify.register(BlockDiagonal)
+def numba_funcify_BlockDiagonal(op, node, **kwargs):
+    dtype = node.outputs[0].dtype
+
+    # TODO: Why do we always inline all functions? It doesn't work with starred args, so can't use it in this case.
+    @numba_basic.numba_njit(inline="never")
+    def block_diag(*arrs):
+        shapes = np.array([a.shape for a in arrs], dtype="int")
+        out_shape = [int(s) for s in np.sum(shapes, axis=0)]
+        out = np.zeros((out_shape[0], out_shape[1]), dtype=dtype)
+
+        r, c = 0, 0
+        for arr, shape in zip(arrs, shapes):
+            rr, cc = shape
+            out[r : r + rr, c : c + cc] = arr
+            r += rr
+            c += cc
+        return out
+
+    return block_diag
diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py
index 363400416f..96105adc5c 100644
--- a/pytensor/sparse/basic.py
+++ b/pytensor/sparse/basic.py
@@ -7,6 +7,7 @@
 TODO: Automatic methods for determining best sparse format?
 
 """
+from typing import Literal
 from warnings import warn
 
 import numpy as np
@@ -47,6 +48,7 @@
     trunc,
 )
 from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.slinalg import BaseBlockDiagonal, _largest_common_dtype
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.type import continuous_dtypes as tensor_continuous_dtypes
 from pytensor.tensor.type import discrete_dtypes as tensor_discrete_dtypes
@@ -60,7 +62,6 @@
 
 sparse_formats = ["csc", "csr"]
 
-
 """
 Types of sparse matrices to use for testing.
 
@@ -183,7 +184,6 @@ def as_sparse_variable(x, name=None, ndim=None, **kwargs):
 
 
 as_sparse = as_sparse_variable
 
-
 as_sparse_or_tensor_variable = as_symbolic
 
@@ -1800,7 +1800,7 @@ def infer_shape(self, fgraph, node, shapes):
         return r
 
     def __str__(self):
-        return f"{self.__class__.__name__ }{{axis={self.axis}}}"
+        return f"{self.__class__.__name__}{{axis={self.axis}}}"
 
 
 def sp_sum(x, axis=None, sparse_grad=False):
@@ -2775,19 +2775,14 @@ def comparison(self, x, y):
 
 greater_equal_s_d = GreaterEqualSD()
 
-
 eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
 
-
 neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
 
-
 lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
 
-
 gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
 
-
 le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
 
 ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
@@ -2992,7 +2987,7 @@ def __str__(self):
         l = []
         if self.inplace:
             l.append("inplace")
-        return f"{self.__class__.__name__ }{{{', '.join(l)}}}"
+        return f"{self.__class__.__name__}{{{', '.join(l)}}}"
 
     def make_node(self, x):
         """
@@ -3291,6 +3286,7 @@ class TrueDot(Op):
     # Simplify code by splitting into DotSS and DotSD.
 
     __props__ = ()
+
     # The grad_preserves_dense attribute doesn't change the
     # execution behavior. To let the optimizer merge nodes with
     # different values of this attribute we shouldn't compare it
@@ -4260,3 +4256,86 @@ def grad(self, inputs, grads):
 
 
 construct_sparse_from_list = ConstructSparseFromList()
+
+
+class SparseBlockDiagonal(BaseBlockDiagonal):
+    __props__ = (
+        "n_inputs",
+        "format",
+    )
+
+    def __init__(self, n_inputs: int, format: Literal["csc", "csr"] = "csc"):
+        super().__init__(n_inputs)
+        self.format = format
+
+    def make_node(self, *matrices):
+        matrices = self._validate_and_prepare_inputs(
+            matrices, as_sparse_or_tensor_variable
+        )
+        dtype = _largest_common_dtype(matrices)
+        out_type = matrix(format=self.format, dtype=dtype)
+
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        dtype = node.outputs[0].type.dtype
+        output_storage[0][0] = scipy.sparse.block_diag(
+            inputs, format=self.format
+        ).astype(dtype)
+
+
+def block_diag(*matrices: TensorVariable, format: Literal["csc", "csr"] = "csc"):
+    r"""
+    Construct a block diagonal matrix from a sequence of input matrices.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
+
+        Note that the input matrices need not be sparse themselves, and will be automatically converted to the
+        requested format if they are not.
+
+    format: str, optional
+        The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csc'.
+
+    Returns
+    -------
+    out: sparse matrix tensor
+        Symbolic sparse matrix in the specified format.
+
+    Examples
+    --------
+    Create a sparse block diagonal matrix from two sparse 2x2 matrices:
+
+    .. code-block:: python
+
+        import numpy as np
+        from pytensor.sparse import block_diag
+        from scipy.sparse import csr_matrix
+
+        A = csr_matrix([[1, 2], [3, 4]])
+        B = csr_matrix([[5, 6], [7, 8]])
+        result_sparse = block_diag(A, B, format='csr')
+
+        print(result_sparse)
+        >>> SparseVariable{csr,int32}
+
+        print(result_sparse.toarray().eval())
+        >>> array([[1, 2, 0, 0],
+        >>>        [3, 4, 0, 0],
+        >>>        [0, 0, 5, 6],
+        >>>        [0, 0, 7, 8]])
+    """
+    if len(matrices) == 1:
+        return matrices[0]
+
+    _sparse_block_diagonal = SparseBlockDiagonal(n_inputs=len(matrices), format=format)
+    return _sparse_block_diagonal(*matrices)
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 207fd4909a..4b043a6471 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -4269,6 +4269,26 @@ def take_along_axis(arr, indices, axis=0):
     return arr[_make_along_axis_idx(arr.shape, indices, axis)]
 
 
+def ix_(*args):
+    """
+    PyTensor np.ix_ analog
+
+    See numpy.lib.index_tricks.ix_ for reference
+    """
+    out = []
+    nd = len(args)
+    for k, new in enumerate(args):
+        if new is None:
+            out.append(slice(None))
+            continue
+        new = as_tensor(new)
+        if new.ndim != 1:
+            raise ValueError("Cross index must be 1 dimensional")
+        new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
+        out.append(new)
+    return tuple(out)
+
+
 __all__ = [
     "take_along_axis",
     "expand_dims",
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index f96dec5a35..aae80fb578 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,6 +1,7 @@
 import logging
 import typing
 import warnings
+from functools import reduce
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
@@ -23,7 +24,6 @@
 
 if TYPE_CHECKING:
     from pytensor.tensor import TensorLike
 
-
 logger = logging.getLogger(__name__)
 
@@ -908,6 +908,107 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     )
 
 
+def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
+    return reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
+
+
+class BaseBlockDiagonal(Op):
+    __props__ = ("n_inputs",)
+
+    def __init__(self, n_inputs):
+        input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
+        self.gufunc_signature = f"{input_sig}->(m,n)"
+
+        if n_inputs == 0:
+            raise ValueError("n_inputs must be greater than 0")
+        self.n_inputs = n_inputs
+
+    def grad(self, inputs, gout):
+        shapes = pt.stack([i.shape for i in inputs])
+        index_end = shapes.cumsum(0)
+        index_begin = index_end - shapes
+        slices = [
+            ptb.ix_(
+                pt.arange(index_begin[i, 0], index_end[i, 0]),
+                pt.arange(index_begin[i, 1], index_end[i, 1]),
+            )
+            for i in range(len(inputs))
+        ]
+        return [gout[0][slc] for slc in slices]
+
+    def infer_shape(self, fgraph, nodes, shapes):
+        first, second = zip(*shapes)
+        return [(pt.add(*first), pt.add(*second))]
+
+    def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
+        if len(matrices) != self.n_inputs:
+            raise ValueError(
+                f"Expected {self.n_inputs} matri{'ces' if self.n_inputs > 1 else 'x'}, got {len(matrices)}"
+            )
+        matrices = list(map(as_tensor_func, matrices))
+        if any(mat.type.ndim != 2 for mat in matrices):
+            raise TypeError("All inputs must have dimension 2")
+        return matrices
+
+
+class BlockDiagonal(BaseBlockDiagonal):
+    __props__ = ("n_inputs",)
+
+    def make_node(self, *matrices):
+        matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
+        dtype = _largest_common_dtype(matrices)
+        out_type = pytensor.tensor.matrix(dtype=dtype)
+        return Apply(self, matrices, [out_type])
+
+    def perform(self, node, inputs, output_storage, params=None):
+        dtype = node.outputs[0].type.dtype
+        output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
+
+
+def block_diag(*matrices: TensorVariable):
+    """
+    Construct a block diagonal matrix from a sequence of input tensors.
+
+    Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal:
+
+    [[A, 0, 0],
+     [0, B, 0],
+     [0, 0, C]]
+
+    Parameters
+    ----------
+    A, B, C ... : tensors
+        Input tensors to form the block diagonal matrix. The last two dimensions of the inputs will be used, and all
+        inputs should have at least 2 dimensions.
+
+    Returns
+    -------
+    out: tensor
+        The block diagonal matrix formed from the input matrices.
+
+    Examples
+    --------
+    Create a block diagonal matrix from two 2x2 matrices:
+
+    .. code-block:: python
+
+        import numpy as np
+        from pytensor.tensor.linalg import block_diag
+
+        A = np.array([[1, 2], [3, 4]])
+        B = np.array([[5, 6], [7, 8]])
+
+        result = block_diag(A, B)
+        print(result.eval())
+        >>> array([[1, 2, 0, 0],
+        >>>        [3, 4, 0, 0],
+        >>>        [0, 0, 5, 6],
+        >>>        [0, 0, 7, 8]])
+    """
+    _block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
+    return _block_diagonal_matrix(*matrices)
+
+
 __all__ = [
     "cholesky",
     "solve",
@@ -918,4 +1019,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
     "solve_continuous_lyapunov",
     "solve_discrete_are",
     "solve_triangular",
+    "block_diag",
 ]
diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py
index 4ae9531f9b..53e154facc 100644
--- a/tests/link/jax/test_slinalg.py
+++ b/tests/link/jax/test_slinalg.py
@@ -129,3 +129,37 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
             np.arange(10).astype(config.floatX),
         ],
     )
+
+
+def test_jax_block_diag():
+    A = matrix("A")
+    B = matrix("B")
+    C = matrix("C")
+    D = matrix("D")
+
+    out = pt_slinalg.block_diag(A, B, C, D)
+    out_fg = FunctionGraph([A, B, C, D], [out])
+
+    compare_jax_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5)).astype(config.floatX),
+            np.random.normal(size=(3, 3)).astype(config.floatX),
+            np.random.normal(size=(2, 2)).astype(config.floatX),
+            np.random.normal(size=(4, 4)).astype(config.floatX),
+        ],
+    )
+
+
+def test_jax_block_diag_blockwise():
+    A = pt.tensor3("A")
+    B = pt.tensor3("B")
+    out = pt_slinalg.block_diag(A, B)
+    out_fg = FunctionGraph([A, B], [out])
+    compare_jax_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5, 5)).astype(config.floatX),
+            np.random.normal(size=(5, 3, 3)).astype(config.floatX),
+        ],
+    )
diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py
index 75e016f1e0..33ec1a529c 100644
--- a/tests/link/numba/test_slinalg.py
+++ b/tests/link/numba/test_slinalg.py
@@ -6,11 +6,11 @@
 
 import pytensor
 import pytensor.tensor as pt
 from pytensor import config
+from tests.link.numba.test_basic import compare_numba_and_py
 
 numba = pytest.importorskip("numba")
 
-
 ATOL = 0 if config.floatX.endswith("64") else 1e-6
 RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6
 rng = np.random.default_rng(42849)
@@ -102,3 +102,18 @@ def test_solve_triangular_raises_on_nan_inf(value):
         ValueError, match=re.escape("Non-numeric values (nan or inf) returned ")
     ):
         f(A_tri, b)
+
+
+def test_block_diag():
+    A = pt.matrix("A")
+    B = pt.matrix("B")
+    C = pt.matrix("C")
+    D = pt.matrix("D")
+    X = pt.linalg.block_diag(A, B, C, D)
+
+    A_val = np.random.normal(size=(5, 5))
+    B_val = np.random.normal(size=(3, 3))
+    C_val = np.random.normal(size=(2, 2))
+    D_val = np.random.normal(size=(4, 4))
+    out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X])
+    compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py
index 16fd5fef04..590c76e008 100644
--- a/tests/sparse/test_basic.py
+++ b/tests/sparse/test_basic.py
@@ -51,6 +51,7 @@
     add_s_s_data,
     as_sparse_or_tensor_variable,
     as_sparse_variable,
+    block_diag,
    cast,
     clean,
     construct_sparse_from_list,
@@ -3389,3 +3390,21 @@ def _helper(x, y):
 )
 class TestSharedOptions:
     pass
+
+
+@pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
+@pytest.mark.parametrize("sparse_input", [True, False], ids=["sparse", "dense"])
+def test_block_diagonal(format, sparse_input):
+    from scipy import sparse as sp_sparse
+
+    f_array = sp_sparse.csr_matrix if sparse_input else np.array
+    A = f_array([[1, 2], [3, 4]]).astype(config.floatX)
+    B = f_array([[5, 6], [7, 8]]).astype(config.floatX)
+
+    result = block_diag(A, B, format=format)
+    assert result.owner.op._props_dict() == {"n_inputs": 2, "format": format}
+
+    sp_result = sp_sparse.block_diag([A, B], format=format)
+
+    assert isinstance(result.eval(), type(sp_result))
+    np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())
diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py
index 504d848140..a2cc3c52e8 100644
--- a/tests/tensor/test_slinalg.py
+++ b/tests/tensor/test_slinalg.py
@@ -15,6 +15,7 @@
     Solve,
     SolveBase,
     SolveTriangular,
+    block_diag,
     cho_solve,
     cholesky,
     eigvalsh,
@@ -661,3 +662,40 @@ def test_solve_discrete_are_grad():
         rng=rng,
         abs_tol=atol,
     )
+
+
+def test_block_diagonal():
+    A = np.array([[1.0, 2.0], [3.0, 4.0]])
+    B = np.array([[5.0, 6.0], [7.0, 8.0]])
+    result = block_diag(A, B)
+    assert result.owner.op.core_op._props_dict() == {"n_inputs": 2}
+
+    np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
+
+
+def test_block_diagonal_grad():
+    A = np.array([[1.0, 2.0], [3.0, 4.0]])
+    B = np.array([[5.0, 6.0], [7.0, 8.0]])
+
+    utt.verify_grad(block_diag, pt=[A, B], rng=np.random.default_rng())
+
+
+def test_block_diagonal_blockwise():
+    batch_size = 5
+    A = np.random.normal(size=(batch_size, 2, 2)).astype(config.floatX)
+    B = np.random.normal(size=(batch_size, 4, 4)).astype(config.floatX)
+    result = block_diag(A, B).eval()
+    assert result.shape == (batch_size, 6, 6)
+    for i in range(batch_size):
+        np.testing.assert_allclose(
+            result[i],
+            scipy.linalg.block_diag(A[i], B[i]),
+            atol=1e-4 if config.floatX == "float32" else 1e-8,
+            rtol=1e-4 if config.floatX == "float32" else 1e-8,
+        )
+
+    # Test broadcasting
+    A = np.random.normal(size=(10, batch_size, 2, 2)).astype(config.floatX)
+    B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX)
+    result = block_diag(A, B).eval()
+    assert result.shape == (10, batch_size, 6, 6)
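
Usage sketch (not part of the patch): a minimal example of the new block_diag helpers, exercising both the dense Blockwise-wrapped Op and the sparse variant. Import paths follow the docstrings and tests in this diff; the batched example with np.broadcast_to is an added illustration, not taken from the PR.

    import numpy as np

    from pytensor.sparse import block_diag as sparse_block_diag
    from pytensor.tensor.slinalg import block_diag

    A = np.array([[1.0, 2.0], [3.0, 4.0]])
    B = np.array([[5.0, 6.0], [7.0, 8.0]])

    # Dense: BlockDiagonal is wrapped in Blockwise, so plain 2D inputs work...
    print(block_diag(A, B).eval())

    # ...and leading batch dimensions are broadcast, giving a (5, 4, 4) result here.
    A_batch = np.broadcast_to(A, (5, 2, 2))
    B_batch = np.broadcast_to(B, (5, 2, 2))
    print(block_diag(A_batch, B_batch).eval().shape)

    # Sparse: evaluation returns a scipy.sparse matrix in the requested format.
    print(sparse_block_diag(A, B, format="csc").eval().toarray())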