Skip to content

Commit 9857f0f

Browse files
committed
fixed typecasting for tests
1 parent 81d4748 commit 9857f0f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/tensor/rewriting/test_linalg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ def test_diag_blockdiag_rewrite():
583583
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
584584

585585
# Value Test
586-
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
586+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
587587
bd_output_test = scipy.linalg.block_diag(
588588
*[sub_matrices_test[i] for i in range(n_matrices)]
589589
)
@@ -610,7 +610,7 @@ def test_det_blockdiag_rewrite():
610610
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
611611

612612
# Value Test
613-
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
613+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
614614
bd_output_test = scipy.linalg.block_diag(
615615
*[sub_matrices_test[i] for i in range(n_matrices)]
616616
)
@@ -639,7 +639,7 @@ def test_slogdet_blockdiag_rewrite():
639639
assert not any(isinstance(node.op, BlockDiagonal) for node in nodes)
640640

641641
# Value Test
642-
sub_matrices_test = np.random.rand(n_matrices, *matrix_size)
642+
sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX)
643643
bd_output_test = scipy.linalg.block_diag(
644644
*[sub_matrices_test[i] for i in range(n_matrices)]
645645
)

0 commit comments

Comments
 (0)