Skip to content

Commit 6d93813

Browse files
Cast output of cholesky to input dtype
1 parent dc2ce68 commit 6d93813

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytensor/tensor/slinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def make_node(self, x):
5656
def perform(self, node, inputs, outputs):
5757
(x,) = inputs
5858
(z,) = outputs
59-
59+
input_dtype = x.dtype
6060
try:
6161
if x.flags["C_CONTIGUOUS"] and self.overwrite_a:
6262
# Inputs to the LAPACK functions need to be exactly as expected for overwrite_a to work correctly,
@@ -74,7 +74,7 @@ def perform(self, node, inputs, outputs):
7474
raise
7575
else:
7676
x = np.full_like(x, np.nan)
77-
z[0] = x
77+
z[0] = x.astype(input_dtype)
7878

7979
def L_op(self, inputs, outputs, gradients):
8080
"""

0 commit comments

Comments
 (0)