Skip to content

Commit 6e37d26

Browse files
Correctly register eagerly_inline_alloc_diag as a JAX-only rewrite
1 parent a5355b6 commit 6e37d26

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import numpy as np
55

66
import pytensor
7+
from pytensor.compile import optdb
78
from pytensor.graph import node_rewriter
89
from pytensor.graph.basic import Constant
10+
from pytensor.graph.rewriting.basic import in2out
11+
from pytensor.graph.rewriting.db import SequenceDB
912
from pytensor.link.jax.dispatch.basic import jax_funcify
1013
from pytensor.tensor import get_vector_length
1114
from pytensor.tensor.basic import (
@@ -24,7 +27,6 @@
2427
get_underlying_scalar_constant_value,
2528
)
2629
from pytensor.tensor.exceptions import NotScalarConstantError
27-
from pytensor.tensor.rewriting.basic import register_specialize
2830
from pytensor.tensor.shape import Shape_i
2931

3032

@@ -211,7 +213,6 @@ def tri(*args):
211213
return tri
212214

213215

214-
@register_specialize
215216
@node_rewriter([AllocDiag])
216217
def eagerly_inline_alloc_diag(fgraph, node):
217218
"""
@@ -235,3 +236,14 @@ def eagerly_inline_alloc_diag(fgraph, node):
235236
inline = pytensor.clone_replace(output, {inner_input: input})
236237

237238
return [inline]
239+
240+
241+
remove_alloc_ofg_opt = SequenceDB()
242+
remove_alloc_ofg_opt.register(
243+
"inline_alloc_diag",
244+
in2out(eagerly_inline_alloc_diag),
245+
"jax",
246+
)
247+
248+
# Do this right away so other JAX rewrites can act on the inner graph
249+
optdb.register("jax_inline_alloc_diag", remove_alloc_ofg_opt, "jax", position=0)

0 commit comments

Comments
 (0)