Skip to content

Commit e10b951

Browse files
committed
Don't run useless fusion and inplace rewrites in JAX mode
1 parent 36161e8 commit e10b951

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pytensor/compile/mode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
454454

455455
JAX = Mode(
456456
JAXLinker(),
457-
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
457+
RewriteDatabaseQuery(
458+
include=["fast_run", "jax"],
459+
exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
460+
),
458461
)
459462
NUMBA = Mode(
460463
NumbaLinker(),

0 commit comments

Comments
 (0)