Skip to content

Commit fa7d9da

Browse files
committed
fixed typecasting for tests
1 parent 48f6527 commit fa7d9da

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def test_diag_blockdiag_rewrite():
677677
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
678678

679679
# Value Test
680-
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
680+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
681681
bd_output_test = scipy.linalg.block_diag(
682682
*[sub_matrices_test[i] for i in range(n_matrices)]
683683
)
@@ -704,7 +704,7 @@ def test_det_blockdiag_rewrite():
704704
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
705705

706706
# Value Test
707-
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
707+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
708708
bd_output_test = scipy.linalg.block_diag(
709709
*[sub_matrices_test[i] for i in range(n_matrices)]
710710
)
@@ -733,7 +733,7 @@ def test_slogdet_blockdiag_rewrite():
733733
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
734734

735735
# Value Test
736-
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
736+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
737737
bd_output_test = scipy.linalg.block_diag(
738738
*[sub_matrices_test[i] for i in range(n_matrices)]
739739
)

0 commit comments

Comments
 (0)