Skip to content

Commit 1784965

Browse files
committed
Move push_elemwise_constants to post_fusion pass
1 parent ccca97d commit 1784965

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

pytensor/compile/mode.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,15 @@ def apply(self, fgraph):
262262
"fast_run",
263263
"fusion",
264264
"local_elemwise_fusion",
265-
position=49,
265+
position=48.7,
266+
)
267+
268+
optdb.register(
269+
"post_fusion",
270+
EquilibriumDB(),
271+
"fast_run",
272+
"fast_compile",
273+
position=48.8,
266274
)
267275

268276
# especially constant merge

pytensor/tensor/rewriting/elemwise.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def push_elemwise_constants(fgraph, node):
411411
contained scalar op.
412412
"""
413413
op = node.op
414-
415414
if not isinstance(op, Elemwise):
416415
return False
417416

@@ -465,7 +464,7 @@ def is_constant_scalar(x):
465464
)
466465

467466

468-
compile.optdb["specialize"].register(
467+
compile.optdb["post_fusion"].register(
469468
"push_elemwise_constants",
470469
push_elemwise_constants,
471470
"fast_run_numba",

0 commit comments

Comments
 (0)