Skip to content

Commit afe2a65

Browse files
Remove AllocDiag2 from graphs in the specialization phase, after other rewrites that need it have fired.
1 parent dbfe92c commit afe2a65

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import jax.numpy as jnp
44
import numpy as np
55

6+
import pytensor
7+
from pytensor.graph import node_rewriter
68
from pytensor.graph.basic import Constant
79
from pytensor.link.jax.dispatch.basic import jax_funcify
810
from pytensor.tensor import get_vector_length
911
from pytensor.tensor.basic import (
1012
Alloc,
13+
AllocDiag2,
1114
AllocEmpty,
1215
ARange,
1316
ExtractDiag,
@@ -21,6 +24,7 @@
2124
get_underlying_scalar_constant_value,
2225
)
2326
from pytensor.tensor.exceptions import NotScalarConstantError
27+
from pytensor.tensor.rewriting.basic import register_specialize
2428
from pytensor.tensor.shape import Shape_i
2529

2630

@@ -205,3 +209,28 @@ def tri(*args):
205209
return jnp.tri(*args, dtype=op.dtype)
206210

207211
return tri
212+
213+
214+
@register_specialize
215+
@node_rewriter([AllocDiag2])
216+
def eagerly_inline_alloc_diag(fgraph, node):
217+
"""
218+
Inline `AllocDiag2` OpFromGraph into the graph so the component Ops can themselves be jaxified
219+
Parameters
220+
----------
221+
fgraph: FunctionGraph
222+
The function graph being rewritten
223+
node: Apply
224+
Node of the function graph to be optimized
225+
226+
Returns
227+
-------
228+
229+
"""
230+
[input] = node.inputs
231+
[output] = node.op.inner_outputs
232+
inner_input = output.owner.inputs[1]
233+
234+
inline = pytensor.clone_replace(output, {inner_input: input})
235+
236+
return [inline]

0 commit comments

Comments
 (0)