
Commit c08d288

PyTorch inline constants in dispatch to avoid graph breaks
1 parent 231a977 commit c08d288

File tree

4 files changed: +74 −9 lines

pytensor/link/pytorch/dispatch/basic.py
pytensor/link/pytorch/dispatch/scalar.py
pytensor/link/pytorch/dispatch/shape.py
pytensor/link/pytorch/dispatch/subtensor.py

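In short: several PyTorch dispatch functions received op inputs such as axes, shapes, and indices as runtime tensors and had to read their values inside the compiled function (e.g. via .item() or .tolist()), which torch.compile answers with a graph break. When those inputs are compile-time Constant nodes in the PyTensor graph, the dispatchers below now bake the plain Python values into the returned closure instead, so the traced graph never touches them at runtime.

A minimal sketch of the effect, assuming a recent PyTorch (2.1+) where torch._dynamo.explain(fn)(*args) reports graph breaks; the function names are illustrative, not from the commit:

    import torch

    def cat_runtime_axis(axis, *tensors):
        # The axis value lives in a tensor: .item() forces torch.compile to
        # fall back to eager Python mid-trace, i.e. a graph break.
        return torch.cat(tensors, dim=axis.item())

    def make_cat_constant_axis(axis):
        # The axis is a plain Python int captured by the closure, so the
        # traced graph never reads a runtime value.
        def cat_constant_axis(*tensors):
            return torch.cat(tensors, dim=axis)

        return cat_constant_axis

    x, y = torch.ones(2, 3), torch.zeros(2, 3)
    print(torch._dynamo.explain(cat_runtime_axis)(torch.tensor(0), x, y).graph_break_count)  # >= 1
    print(torch._dynamo.explain(make_cat_constant_axis(0))(x, y).graph_break_count)          # 0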

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 37 additions & 6 deletions
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
     Eye,
     Join,
     MakeVector,
+    Split,
     TensorFromScalar,
 )
 
@@ -120,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
 
-        return torch.cat(tensors, dim=axis)
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
 
-    return join
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
@@ -185,3 +196,23 @@ def tensorfromscalar(x):
         return torch.as_tensor(x)
 
     return tensorfromscalar
+
+
+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
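For Join, the axis is typically written as a literal, so it arrives at dispatch as a Constant and join_constant_axis closes over a plain int while ignoring the runtime axis argument (the leading _ keeps the calling convention intact). Split gets the same treatment for both the axis and the split sizes; only the non-constant fallback still calls .tolist() and .item(), which is exactly where a graph break would occur. A hypothetical usage sketch, assuming this branch's PYTORCH compile mode:

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")

    # The literal axis 0 becomes a Constant node, so the backend inlines it
    # and torch.compile sees a fixed dim.
    joined = pt.join(0, x, y)

    # Literal split sizes and axis are Constants too, hitting the
    # split_constant_axis_and_sizes path.
    parts = pt.split(joined, splits_size=[1, 3], n_splits=2, axis=0)

    fn = pytensor.function([x, y], parts, mode="PYTORCH")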

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 6 additions & 0 deletions
@@ -5,12 +5,18 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Invert,
     ScalarOp,
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not
+
+
 @pytorch_funcify.register(ScalarOp)
 def pytorch_funcify_ScalarOp(op, node, **kwargs):
     """Return pytorch function that implements the same computation as the Scalar Op.

pytensor/link/pytorch/dispatch/shape.py

Lines changed: 16 additions & 3 deletions
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
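Reshape follows the same pattern: when the target shape is a Constant (the common case for literal shapes), it is converted once into a static Python tuple, so torch.reshape never sees a tensor-valued shape; the fallback's tuple(shape) iterates a runtime tensor, which torch.compile treats as data-dependent. A hypothetical usage sketch under the same PYTORCH-mode assumption as above:

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # The literal (2, 3) enters the graph as a Constant, so dispatch returns
    # reshape_constant_shape with the tuple already baked in.
    out = x.reshape((2, 3))
    fn = pytensor.function([x], out, mode="PYTORCH")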

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
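Subtensor's inputs after x are the scalar pieces of a basic index (slice bounds, steps, integer indices). When all of them are Constants, indices_from_subtensor can rebuild the full index tuple once at dispatch time, and the compiled function indexes with that fixed tuple; otherwise the fallback reassembles it per call and, as the comment notes, introduces a graph break. A hypothetical usage sketch:

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    # The slice bound 1 and the column index 0 are graph Constants, so the
    # backend precomputes the index tuple (slice(1, None), 0).
    out = x[1:, 0]
    fn = pytensor.function([x], out, mode="PYTORCH")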
