File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed
pytensor/link/jax/dispatch Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change 3
3
import jax .numpy as jnp
4
4
import numpy as np
5
5
6
+ import pytensor
7
+ from pytensor .graph import node_rewriter
6
8
from pytensor .graph .basic import Constant
7
9
from pytensor .link .jax .dispatch .basic import jax_funcify
8
10
from pytensor .tensor import get_vector_length
9
11
from pytensor .tensor .basic import (
10
12
Alloc ,
13
+ AllocDiag2 ,
11
14
AllocEmpty ,
12
15
ARange ,
13
16
ExtractDiag ,
21
24
get_underlying_scalar_constant_value ,
22
25
)
23
26
from pytensor .tensor .exceptions import NotScalarConstantError
27
+ from pytensor .tensor .rewriting .basic import register_specialize
24
28
from pytensor .tensor .shape import Shape_i
25
29
26
30
@@ -205,3 +209,28 @@ def tri(*args):
205
209
return jnp .tri (* args , dtype = op .dtype )
206
210
207
211
return tri
212
+
213
+
214
@register_specialize
@node_rewriter([AllocDiag2])
def eagerly_inline_alloc_diag(fgraph, node):
    """
    Inline an `AllocDiag2` OpFromGraph into the surrounding graph so the
    component Ops can themselves be jaxified.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph being rewritten
    node : Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable
        Single-element list containing the inlined replacement for the
        node's output.
    """
    # Avoid shadowing the builtin ``input``.
    [outer_input] = node.inputs
    [inner_output] = node.op.inner_outputs
    # NOTE(review): assumes the variable to substitute is the second input
    # of the inner output's owner — TODO confirm against AllocDiag2's
    # inner-graph construction.
    inner_input = inner_output.owner.inputs[1]

    # Splice the OpFromGraph's inner graph into the outer graph by
    # replacing its internal input with the outer input.
    inlined = pytensor.clone_replace(inner_output, {inner_input: outer_input})

    return [inlined]
You can’t perform that action at this time.
0 commit comments