Add linalg.block_diag and sparse equivalent #576
Changes from 13 commits
@@ -1,3 +1,4 @@
import functools as ft

import logging
import typing
import warnings
@@ -23,7 +24,6 @@
if TYPE_CHECKING:
    from pytensor.tensor import TensorLike


logger = logging.getLogger(__name__)
@@ -908,6 +908,102 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
    )

def largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:
    # Promote all input dtypes to a single common dtype via NumPy's promotion rules.
    return ft.reduce(lambda l, r: np.promote_types(l, r), [x.dtype for x in tensors])
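A minimal sketch of the promotion this helper performs (plain NumPy, no PyTensor needed):

.. code-block:: python

    import functools as ft
    import numpy as np

    # Folding np.promote_types over the dtypes: int32 + float32 -> float64,
    # then float64 + float64 -> float64.
    print(ft.reduce(np.promote_types, ["int32", "float32", "float64"]))  # float64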
class BaseBlockDiagonal(Op):
    __props__ = ("gufunc_signature",)

    def __init__(self, n_inputs):
        # Each input is a core matrix (m_i, n_i); the output is a single (m, n)
        # matrix, e.g. "(m0,n0),(m1,n1)->(m,n)" for two inputs.
        input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
        self.gufunc_signature = f"{input_sig}->(m,n)"
    def grad(self, inputs, gout):
        # The gradient of each input is the matching diagonal block of the
        # output gradient, extracted with ix_ index grids.
        shapes = pt.stack([i.shape for i in inputs])
        index_end = shapes.cumsum(0)
        index_begin = index_end - shapes
        slices = [
            ptb.ix_(
                pt.arange(index_begin[i, 0], index_end[i, 0]),
                pt.arange(index_begin[i, 1], index_end[i, 1]),
            )
            for i in range(len(inputs))
        ]
        return [gout[0][slc] for slc in slices]

    def infer_shape(self, fgraph, nodes, shapes):
        # Output dims are the sums of the input dims along each axis.
        first, second = zip(*shapes)
        return [(pt.add(*first), pt.add(*second))]
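A sketch of what the gradient above computes, assuming the `block_diag` wrapper defined later in this diff: each input's gradient is just the matching block cut out of the output gradient.

.. code-block:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.matrix("A")
    B = pt.matrix("B")
    cost = block_diag(A, B).sum()
    gA, gB = pt.grad(cost, [A, B])
    # The gradient graph uses both input shapes, so both inputs must be supplied.
    print(gA.eval({A: np.zeros((2, 2)), B: np.zeros((3, 3))}))  # 2x2 block of ones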

class BlockDiagonalMatrix(BaseBlockDiagonal):
    def make_node(self, *matrices):
        if not matrices:
            raise ValueError("no matrices to allocate")
        dtype = largest_common_dtype(matrices)
        matrices = list(map(pt.as_tensor, matrices))

        if any(mat.type.ndim != 2 for mat in matrices):
            raise TypeError("all data arguments must be matrices")

        out_type = pytensor.tensor.matrix(dtype=dtype)
        return Apply(self, matrices, [out_type])

    def perform(self, node, inputs, output_storage, params=None):
        # Delegate to scipy and cast to the promoted dtype chosen in make_node.
        dtype = node.outputs[0].type.dtype
        output_storage[0][0] = scipy.linalg.block_diag(*inputs).astype(dtype)
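A quick check of the dtype promotion above (a sketch, assuming the `block_diag` wrapper below):

.. code-block:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.as_tensor(np.eye(2, dtype="int32"))
    B = pt.as_tensor(np.eye(3, dtype="float64"))
    print(block_diag(A, B).dtype)  # float64, the largest common dtype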

def block_diag(*matrices: TensorVariable, name=None):
""" | ||
Construct a block diagonal matrix from a sequence of input tensors. | ||
|
||
Given the inputs `A`, `B` and `C`, the output will have these arrays arranged on the diagonal: | ||
|
||
[[A, 0, 0], | ||
[0, B, 0], | ||
[0, 0, C]] | ||
|
||
Parameters | ||
---------- | ||
A, B, C ... : tensors | ||
        Input matrices to form the block diagonal matrix. The block diagonal matrix is formed
        from the right-most two dimensions of each input; any leading batch dimensions must be
        broadcastable against each other.
    name: str, optional
        Name of the block diagonal matrix.

    Returns
    -------
    out: tensor
        The block diagonal matrix formed from the input matrices.
    Examples
    --------
    Create a block diagonal matrix from two 2x2 matrices:

    .. code-block:: python

        import numpy as np
        import pytensor.tensor as pt
        from pytensor.tensor.slinalg import block_diag

        A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
        B = pt.as_tensor_variable(np.array([[5, 6], [7, 8]]))

        result = block_diag(A, B, name='X')
        print(result.eval())
        # Out: array([[1, 2, 0, 0],
        #             [3, 4, 0, 0],
        #             [0, 0, 5, 6],
        #             [0, 0, 7, 8]])
    """
    if len(matrices) == 1:  # graph optimization
        return pt.as_tensor(matrices[0])

    _block_diagonal_matrix = Blockwise(BlockDiagonalMatrix(n_inputs=len(matrices)))
    out = _block_diagonal_matrix(*matrices)
    if name is not None:
        out.name = name
    return out
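Because the core Op is wrapped in `Blockwise`, batched inputs also work; a hedged sketch with shapes chosen for illustration:

.. code-block:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    batched_A = pt.tensor3("batched_A")  # (batch, 2, 2)
    single_B = pt.matrix("single_B")     # (3, 3), broadcast across the batch
    out = block_diag(batched_A, single_B)
    print(out.eval({batched_A: np.ones((4, 2, 2)), single_B: np.eye(3)}).shape)
    # (4, 5, 5): each 2x2 slice is paired with the same 3x3 block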

__all__ = [
    "cholesky",
    "solve",
@@ -918,4 +1014,5 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
    "solve_continuous_lyapunov",
    "solve_discrete_are",
    "solve_triangular",
    "block_diag",
]
Review discussion (resolved) on the batch-dimension docs above:

ricardoV94: This is not correct. Blockwise accepts different numbers of batch dims and also broadcasts when they have length 1.

jessegrabowski: I got errors when I tried tensors with different batch dims, but I didn't try broadcasting to dimensions with size 1.

ricardoV94: Do you still see errors? Blockwise should introduce expand dims, so the only failure case would be broadcasting?

jessegrabowski: Broadcasting works, I added a test for it. It was failing when I tried different batch sizes, which doesn't make sense anyway I think.

ricardoV94: What do you mean different batch sizes? Blockwise adds expand dims automatically to align the number of batch dims, so that shouldn't be possible?

jessegrabowski: This errors: […] with: […]. But I think it's supposed to. What does it even mean to batch those two together?

ricardoV94: Yeah, that's invalid, batch shapes must be broadcastable. 8 and 5 are not broadcastable. I thought you were saying inputs with different number of dimensions were failing.
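A sketch of the rule settled in that thread: leading batch dimensions broadcast like NumPy, so a length-1 (or missing) batch dim is fine, while mismatched lengths such as 8 vs 5 raise.

.. code-block:: python

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.slinalg import block_diag

    A = pt.as_tensor(np.ones((8, 2, 2)))
    B = pt.as_tensor(np.ones((1, 3, 3)))  # batch dim of length 1 broadcasts
    print(block_diag(A, B).eval().shape)  # (8, 5, 5)

    # Invalid: batch sizes 8 and 5 are not broadcastable, so this raises.
    # block_diag(A, pt.as_tensor(np.ones((5, 3, 3))))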