Skip to content

Commit d546e5b

Browse files
Skip rewrite tests when mode=FAST_COMPILE
1 parent 5b96d83 commit d546e5b

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,6 @@ def softplus(x):
812812
@numba_funcify.register(Cholesky)
813813
def numba_funcify_Cholesky(op, node, **kwargs):
814814
lower = op.lower
815-
816815
out_dtype = node.outputs[0].type.numpy_dtype
817816

818817
if lower:

pytensor/tensor/slinalg.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ class Cholesky(Op):
4242
`scipy.linalg.LinAlgError` if the matrix is not positive definite.
4343
If on_error is set to 'nan', it will return a matrix containing
4444
nans instead.
45+
overwrite_a: bool, ignored
46+
Whether to use the same memory for the output as `a`. This argument is ignored, and
47+
included only for consistency with scipy.linalg.cholesky.
4548
"""
4649

4750
# TODO: for specific dtypes
@@ -52,11 +55,13 @@ class Cholesky(Op):
5255

5356
def __init__(self, *, lower=True, on_error="raise", overwrite_a=False):
5457
self.lower = lower
55-
self.overwrite_a = overwrite_a
58+
5659
if on_error not in ("raise", "nan"):
5760
raise ValueError('on_error must be one of "raise" or ""nan"')
5861
self.on_error = on_error
59-
if overwrite_a:
62+
63+
self.overwrite_a = overwrite_a
64+
if self.overwrite_a:
6065
self.destroy_map = {0: [0]}
6166

6267
def infer_shape(self, fgraph, node, shapes):
@@ -87,7 +92,7 @@ def perform(self, node, inputs, outputs):
8792
if self.on_error == "raise":
8893
raise
8994
else:
90-
x = (np.zeros(x.shape) * np.nan).astype(x.dtype)
95+
x = np.full_like(x, np.nan)
9196
z[0] = x
9297

9398
def L_op(self, inputs, outputs, gradients):

tests/tensor/rewriting/test_linalg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,10 @@ def test_invalid_batched_a(self):
308308
)
309309

310310

311+
@pytest.mark.skipif(
312+
config.mode == "FAST_COMPILE",
313+
reason="inplace rewrites disabled when mode is FAST_COMPILE",
314+
)
311315
def test_local_inplace_cholesky():
312316
X = matrix("X")
313317
L = cholesky(X, overwrite_a=False, lower=True)

0 commit comments

Comments
 (0)