Skip to content

Commit 62bab9d

Browse files
Use as_sparse_or_tensor_variable in SparseBlockDiagonalMatrix to allow sparse matrix inputs to pytensor.sparse.block_diag
1 parent fd26b74 commit 62bab9d

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

pytensor/sparse/basic.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4263,7 +4263,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None)
42634263
if not matrices:
42644264
raise ValueError("no matrices to allocate")
42654265
dtype = largest_common_dtype(matrices)
4266-
matrices = list(map(pytensor.tensor.as_tensor, matrices))
4266+
matrices = list(map(as_sparse_or_tensor_variable, matrices))
42674267

42684268
if any(mat.type.ndim != 2 for mat in matrices):
42694269
raise TypeError("all data arguments must be matrices")
@@ -4273,7 +4273,7 @@ def make_node(self, *matrices, format: Literal["csc", "csr"] = "csc", name=None)
42734273

42744274
def perform(self, node, inputs, output_storage, params=None):
42754275
format = node.outputs[0].type.format
4276-
dtype = largest_common_dtype(inputs)
4276+
dtype = node.outputs[0].type.dtype
42774277
output_storage[0][0] = scipy.sparse.block_diag(inputs, format=format).astype(
42784278
dtype
42794279
)
@@ -4296,9 +4296,12 @@ def block_diag(
42964296
42974297
Parameters
42984298
----------
4299-
A, B, C ... : tensors
4300-
Input sparse matrices to form the block diagonal matrix. Each matrix should have the same number of dimensions,
4299+
A, B, C ... : tensors or array-like
4300+
Inputs to form the block diagonal matrix. Each input should have the same number of dimensions,
43014301
and the block diagonal matrix will be formed using the right-most two dimensions of each input matrix.
4302+
4303+
Note that the input matrices need not be sparse themselves, and will be automatically converted to the
4304+
requested format if they are not.
43024305
format: str, optional
43034306
The format of the output sparse matrix. One of 'csr' or 'csc'. Default is 'csr'. Ignored if sparse=False.
43044307
name: str, optional
@@ -4321,9 +4324,15 @@ def block_diag(
43214324
A = csr_matrix([[1, 2], [3, 4]])
43224325
B = csr_matrix([[5, 6], [7, 8]])
43234326
result_sparse = block_diag(A, B, format='csr', name='X')
4324-
print(result_sparse.eval())
43254327
4326-
The resulting sparse block diagonal matrix `result_sparse` is in CSR format.
4328+
print(result_sparse)
4329+
>>> SparseVariable{csr,int32}
4330+
4331+
print(result_sparse.toarray().eval())
4332+
>>> array([[1, 2, 0, 0],
4333+
>>> [3, 4, 0, 0],
4334+
>>> [0, 0, 5, 6],
4335+
>>> [0, 0, 7, 8]])
43274336
"""
43284337
if len(matrices) == 1:
43294338
return matrices

tests/sparse/test_basic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3394,11 +3394,12 @@ class TestSharedOptions:
33943394

33953395
@pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
33963396
def test_block_diagonal(format):
3397-
from scipy.sparse import block_diag as scipy_block_diag
3397+
from scipy import sparse as sp_sparse
33983398

3399-
matrices = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0], [7.0, 8.0]])]
3400-
result = block_diag(*matrices, format=format, name="X")
3401-
sp_result = scipy_block_diag(matrices, format=format)
3399+
A = sp_sparse.csr_matrix([[1, 2], [3, 4]])
3400+
B = sp_sparse.csr_matrix([[5, 6], [7, 8]])
3401+
result = block_diag(A, B, format=format, name="X")
3402+
sp_result = sp_sparse.block_diag([A, B], format=format)
34023403

34033404
assert isinstance(result.eval(), type(sp_result))
34043405
np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray())

0 commit comments

Comments
 (0)