Skip to content

Commit 1de23db

Browse files
Test sparse and dense inputs to pytensor.sparse.block_diag
1 parent 62bab9d commit 1de23db

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/sparse/test_basic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3393,11 +3393,14 @@ class TestSharedOptions:
33933393

33943394

33953395
@pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"])
3396-
def test_block_diagonal(format):
3396+
@pytest.mark.parametrize("sparse_input", [True, False], ids=["sparse", "dense"])
3397+
def test_block_diagonal(format, sparse_input):
33973398
from scipy import sparse as sp_sparse
33983399

3399-
A = sp_sparse.csr_matrix([[1, 2], [3, 4]])
3400-
B = sp_sparse.csr_matrix([[5, 6], [7, 8]])
3400+
f_array = sp_sparse.csr_matrix if sparse_input else np.array
3401+
A = f_array([[1, 2], [3, 4]]).astype(config.floatX)
3402+
B = f_array([[5, 6], [7, 8]]).astype(config.floatX)
3403+
34013404
result = block_diag(A, B, format=format, name="X")
34023405
sp_result = sp_sparse.block_diag([A, B], format=format)
34033406

0 commit comments

Comments
 (0)