4
4
import numpy as np
5
5
6
6
import pytensor
7
+ from pytensor .compile import optdb
7
8
from pytensor .graph import node_rewriter
8
9
from pytensor .graph .basic import Constant
10
+ from pytensor .graph .rewriting .basic import in2out
11
+ from pytensor .graph .rewriting .db import SequenceDB
9
12
from pytensor .link .jax .dispatch .basic import jax_funcify
10
13
from pytensor .tensor import get_vector_length
11
14
from pytensor .tensor .basic import (
24
27
get_underlying_scalar_constant_value ,
25
28
)
26
29
from pytensor .tensor .exceptions import NotScalarConstantError
27
- from pytensor .tensor .rewriting .basic import register_specialize
28
30
from pytensor .tensor .shape import Shape_i
29
31
30
32
@@ -211,7 +213,6 @@ def tri(*args):
211
213
return tri
212
214
213
215
214
- @register_specialize
215
216
@node_rewriter ([AllocDiag ])
216
217
def eagerly_inline_alloc_diag (fgraph , node ):
217
218
"""
@@ -235,3 +236,14 @@ def eagerly_inline_alloc_diag(fgraph, node):
235
236
inline = pytensor .clone_replace (output , {inner_input : input })
236
237
237
238
return [inline ]
239
# Rewrite database that strips OpFromGraph-wrapped Alloc* Ops by inlining
# their inner graphs, so the JAX backend never sees the opaque wrapper.
remove_alloc_ofg_opt = SequenceDB()
remove_alloc_ofg_opt.register(
    "inline_alloc_diag", in2out(eagerly_inline_alloc_diag), "jax"
)

# Registered at position 0 so the inlining happens before every other
# rewrite pass, letting subsequent JAX rewrites act on the exposed inner graph.
optdb.register("jax_inline_alloc_diag", remove_alloc_ofg_opt, "jax", position=0)
0 commit comments