
Commit 143ded6

jessegrabowski authored and Ian Schweer committed
Add OpFromGraph wrapper around alloc_diag (pymc-devs#915)
* Add `OpFromGraph` wrapper around `alloc_diag`
* Remove deprecated `AllocDiag` `Op`, rename `AllocDiag2 -> AllocDiag`
* Set `inline = False`
* Add rewrite to inline all `OpFromGraph` `Op`s
* Add `is_offset_zero` helper to `Eye`
* Add `is_left_expand_dims` and `is_right_expand_dims` attributes to `DimShuffle`
* Seed `test_local_lift_through_linalg` test
1 parent 79232b2 commit 143ded6
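
As a quick orientation, a minimal sketch of the net effect (assuming this commit is applied; `alloc_diag` and `AllocDiag` live in `pytensor.tensor.basic`): `alloc_diag` still builds the same set_subtensor-based graph as before, but now wraps it in an `AllocDiag` Op, a subclass of `OpFromGraph`.

import pytensor.tensor as pt
from pytensor.tensor.basic import AllocDiag, alloc_diag

# alloc_diag now returns a variable whose owner Op is an AllocDiag,
# an OpFromGraph subclass wrapping the underlying graph.
x = pt.vector("x")
m = alloc_diag(x)
assert isinstance(m.owner.op, AllocDiag)
assert m.owner.op.offset == 0  # the offset is stored on the wrapper Op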

File tree

10 files changed: +278 -173 lines changed


pytensor/compile/builders.py

Lines changed: 2 additions & 30 deletions

@@ -8,7 +8,6 @@
 
 from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
-from pytensor.compile.mode import optdb
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -24,7 +23,6 @@
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import HasInnerGraph, Op
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import in2out, node_rewriter
 from pytensor.graph.utils import MissingInputError
 
 
@@ -575,7 +573,7 @@ def lop_overrides(inps, grads):
            for inp_grad in input_grads
            if not isinstance(inp_grad.type, DisconnectedType | NullType)
        ]
-        lop_op = type(self)(
+        lop_op = OpFromGraph(
            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
            outputs=connected_input_grads,
            inline=self.is_inline,
@@ -669,7 +667,7 @@ def _build_and_cache_rop_op(self):
            for out_grad in output_grads
            if not isinstance(out_grad.type, DisconnectedType | NullType)
        ]
-        rop_op = type(self)(
+        rop_op = OpFromGraph(
            inputs=inner_inputs + eval_points,
            outputs=filtered_output_grads,
            inline=self.is_inline,
@@ -852,29 +850,3 @@ def perform(self, node, inputs, outputs):
        assert len(variables) == len(outputs)
        for output, variable in zip(outputs, variables):
            output[0] = variable
-
-
-@node_rewriter([OpFromGraph])
-def inline_ofg_expansion(fgraph, node):
-    """
-    This optimization expands internal graph of OpFromGraph.
-    Only performed if node.op.is_inline == True
-    Doing so can improve optimization at the cost of compilation speed.
-    """
-    op = node.op
-    if not isinstance(op, OpFromGraph):
-        return False
-    if not op.is_inline:
-        return False
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
-
-
-# We want to run this before the first merge optimizer
-# and before the first scan optimizer.
-optdb.register(
-    "inline_ofg_expansion",
-    in2out(inline_ofg_expansion),
-    "fast_compile",
-    "fast_run",
-    position=-0.01,
-)
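
The `type(self)` -> `OpFromGraph` change above matters once `OpFromGraph` has subclasses with required constructor arguments, such as the new `AllocDiag`: the cached L/R-operator Ops must be built as plain `OpFromGraph`s rather than by re-instantiating the subclass. A small sketch of the behavior this enables (assuming the commit is applied):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import alloc_diag

# Differentiating through the OpFromGraph-based AllocDiag builds the
# L-operator as a plain OpFromGraph; with the old `type(self)(...)` this
# would have called AllocDiag.__init__, which requires axis1/axis2/offset.
x = pt.vector("x")
m = alloc_diag(x)
g = pytensor.grad(m.sum(), x)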

pytensor/link/jax/dispatch/basic.py

Lines changed: 24 additions & 0 deletions

@@ -1,10 +1,13 @@
 import warnings
+from collections.abc import Callable
 from functools import singledispatch
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 
+from pytensor.compile import JAX
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp, ViewOp
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
@@ -114,3 +117,24 @@ def viewop(x):
        return x
 
    return viewop
+
+
+@jax_funcify.register(OpFromGraph)
+def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
+    _ = kwargs.pop("storage_map", None)
+
+    # Apply inner rewrites
+    JAX.optimizer(ofg.fgraph)
+    fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)
+
+    if len(ofg.fgraph.outputs) == 1:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)[0]
+
+    else:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)
+
+    return opfromgraph
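
A small usage sketch (assuming JAX is installed; the inner graph and numbers are illustrative): with this dispatch registered, graphs containing an `OpFromGraph` can be compiled with the JAX backend, which rewrites and funcifies the inner `FunctionGraph`.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
ofg = OpFromGraph([x], [2 * x + 1])

# Compiling with mode="JAX" routes through jax_funcify_OpFromGraph above.
fn = pytensor.function([x], ofg(x), mode="JAX")
print(fn(np.array([1.0, 2.0])))  # [3. 5.]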

pytensor/tensor/basic.py

Lines changed: 46 additions & 95 deletions

@@ -21,6 +21,7 @@
 import pytensor.scalar.sharedvar
 from pytensor import compile, config, printing
 from pytensor import scalar as ps
+from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType, grad_undefined
 from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
@@ -1334,6 +1335,25 @@ def infer_shape(self, fgraph, node, in_shapes):
    def grad(self, inp, grads):
        return [grad_undefined(self, i, inp[i]) for i in range(3)]
 
+    @staticmethod
+    def is_offset_zero(node) -> bool:
+        """
+        Test if an Eye Op has a diagonal offset of zero
+
+        Parameters
+        ----------
+        node
+            Eye node to test
+
+        Returns
+        -------
+        is_offset_zero: bool
+            True if the offset is zero (``k = 0``).
+        """
+
+        offset = node.inputs[-1]
+        return isinstance(offset, Constant) and offset.data.item() == 0
+
 
 def eye(n, m=None, k=0, dtype=None):
    """Return a 2-D array with ones on the diagonal and zeros elsewhere.
@@ -3749,109 +3769,37 @@ def trace(a, offset=0, axis1=0, axis2=1):
    return diagonal(a, offset=offset, axis1=axis1, axis2=axis2).sum(-1)
 
 
-class AllocDiag(Op):
-    """An `Op` that copies a vector to the diagonal of a zero-ed matrix."""
+class AllocDiag(OpFromGraph):
+    """
+    Wrapper Op for alloc_diag graphs
+    """
 
-    __props__ = ("offset", "axis1", "axis2")
+    __props__ = ("axis1", "axis2")
 
-    def __init__(self, offset=0, axis1=0, axis2=1):
-        """
-        Parameters
-        ----------
-        offset: int
-            Offset of the diagonal from the main diagonal defined by `axis1`
-            and `axis2`. Can be positive or negative. Defaults to main
-            diagonal (i.e. 0).
-        axis1: int
-            Axis to be used as the first axis of the 2-D sub-arrays to which
-            the diagonals will be allocated. Defaults to first axis (i.e. 0).
-        axis2: int
-            Axis to be used as the second axis of the 2-D sub-arrays to which
-            the diagonals will be allocated. Defaults to second axis (i.e. 1).
-        """
-        warnings.warn(
-            "AllocDiag is deprecated. Use `alloc_diag` instead",
-            FutureWarning,
-        )
-        self.offset = offset
-        if axis1 < 0 or axis2 < 0:
-            raise NotImplementedError("AllocDiag does not support negative axis")
-        if axis1 == axis2:
-            raise ValueError("axis1 and axis2 cannot be the same")
+    def __init__(self, *args, axis1, axis2, offset, **kwargs):
        self.axis1 = axis1
        self.axis2 = axis2
+        self.offset = offset
 
-    def make_node(self, diag):
-        diag = as_tensor_variable(diag)
-        if diag.type.ndim < 1:
-            raise ValueError(
-                "AllocDiag needs an input with 1 or more dimensions", diag.type
-            )
-        return Apply(
-            self,
-            [diag],
-            [diag.type.clone(shape=(None,) * (diag.ndim + 1))()],
-        )
-
-    def perform(self, node, inputs, outputs):
-        (x,) = inputs
-        (z,) = outputs
-
-        axis1 = np.minimum(self.axis1, self.axis2)
-        axis2 = np.maximum(self.axis1, self.axis2)
-        offset = self.offset
-
-        # Create array with one extra dimension for resulting matrix
-        result_shape = x.shape[:-1] + (x.shape[-1] + abs(offset),) * 2
-        result = np.zeros(result_shape, dtype=x.dtype)
-
-        # Create slice for diagonal in final 2 axes
-        idxs = np.arange(x.shape[-1])
-        diagonal_slice = (len(result_shape) - 2) * [slice(None)] + [
-            idxs + np.maximum(0, -offset),
-            idxs + np.maximum(0, offset),
-        ]
-
-        # Fill in final 2 axes with x
-        result[tuple(diagonal_slice)] = x
-
-        if len(x.shape) > 1:
-            # Re-order axes so they correspond to diagonals at axis1, axis2
-            axes = list(range(len(x.shape[:-1])))
-            last_idx = axes[-1]
-            axes = axes[:axis1] + [last_idx + 1] + axes[axis1:]
-            axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
-            result = result.transpose(axes)
-
-        z[0] = result
-
-    def grad(self, inputs, gout):
-        (gz,) = gout
-        return [diagonal(gz, offset=self.offset, axis1=self.axis1, axis2=self.axis2)]
-
-    def infer_shape(self, fgraph, nodes, shapes):
-        (x_shape,) = shapes
-        axis1 = np.minimum(self.axis1, self.axis2)
-        axis2 = np.maximum(self.axis1, self.axis2)
+        super().__init__(*args, **kwargs, strict=True)
 
-        result_shape = list(x_shape[:-1])
-        diag_shape = x_shape[-1] + abs(self.offset)
-        result_shape = result_shape[:axis1] + [diag_shape] + result_shape[axis1:]
-        result_shape = result_shape[:axis2] + [diag_shape] + result_shape[axis2:]
-        return [tuple(result_shape)]
+    @staticmethod
+    def is_offset_zero(node) -> bool:
+        """
+        Test if an AllocDiag Op has a diagonal offset of zero
 
-    def __setstate__(self, state):
-        if "view_map" in state:
-            del state["view_map"]
+        Parameters
+        ----------
+        node
+            AllocDiag node to test
 
-        self.__dict__.update(state)
+        Returns
+        -------
+        is_offset_zero: bool
+            True if the offset is zero (``k = 0``).
+        """
 
-        if "offset" not in state:
-            self.offset = 0
-        if "axis1" not in state:
-            self.axis1 = 0
-        if "axis2" not in state:
-            self.axis2 = 1
+        return node.op.offset == 0
 
 
 def alloc_diag(diag, offset=0, axis1=0, axis2=1):
@@ -3862,6 +3810,7 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
    from pytensor.tensor import set_subtensor
 
    diag = as_tensor_variable(diag)
+
    axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=diag.type.ndim + 1)
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
@@ -3888,7 +3837,9 @@ def alloc_diag(diag, offset=0, axis1=0, axis2=1):
        axes = axes[:axis2] + [last_idx + 2] + axes[axis2:]
        result = result.transpose(axes)
 
-    return result
+    return AllocDiag(
+        inputs=[diag], outputs=[result], axis1=axis1, axis2=axis2, offset=offset
+    )(diag)
 
 
 def diag(v, k=0):
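
A companion sketch for the wrapper Op (assuming the commit is applied): unlike `Eye`, where the offset is a graph input, `AllocDiag` stores `offset` on the Op itself, so its check reads the attribute directly.

import pytensor.tensor as pt
from pytensor.tensor.basic import AllocDiag, alloc_diag

x = pt.vector("x")
# The offset is an Op attribute, not a symbolic input.
assert AllocDiag.is_offset_zero(alloc_diag(x).owner)
assert not AllocDiag.is_offset_zero(alloc_diag(x, offset=2).owner)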

pytensor/tensor/elemwise.py

Lines changed: 8 additions & 0 deletions

@@ -185,6 +185,14 @@ def __init__(self, input_broadcastable, new_order):
        self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
        self.drop = drop
 
+        input_ndim = len(input_broadcastable)
+        self.is_left_expand_dims = self.augment and (
+            input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
+        )
+        self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
+            range(input_ndim)
+        )
+
        if self.inplace:
            self.view_map = {0: [0]}

pytensor/tensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
 import pytensor.tensor.rewriting.subtensor
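
The new `pytensor.tensor.rewriting.ofg` module is imported here for its registration side effects; its contents are not shown in this diff. A hypothetical sketch of what such a module might register, mirroring the registration removed from builders.py above (names, tags, and position are assumptions, not the actual file):

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import optdb
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter


@node_rewriter([OpFromGraph])
def inline_ofg_expansion(fgraph, node):
    # Replace the OpFromGraph node with its inner graph, substituting
    # the node's actual inputs for the inner-graph inputs.
    op = node.op
    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))


# Run early, before the first merge optimizer, as the removed
# registration in builders.py did.
optdb.register(
    "inline_ofg_expansion",
    in2out(inline_ofg_expansion),
    "fast_compile",
    "fast_run",
    position=-0.01,
)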
