Skip to content

Commit 75eef40

Browse files
ricardoV94
authored and
Ian Schweer
committed
PyTorch inline constants in dispatch to avoid graph breaks
1 parent 52c0e89 commit 75eef40

File tree

4 files changed

+65
-13
lines changed

4 files changed

+65
-13
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 31 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from pytensor.compile import PYTORCH
99
from pytensor.compile.builders import OpFromGraph
1010
from pytensor.compile.ops import DeepCopyOp
11+
from pytensor.graph.basic import Constant
1112
from pytensor.graph.fg import FunctionGraph
1213
from pytensor.ifelse import IfElse
1314
from pytensor.link.utils import fgraph_to_python
@@ -121,14 +122,23 @@ def arange(start, stop, step):
121122

122123

123124
@pytorch_funcify.register(Join)
124-
def pytorch_funcify_Join(op, **kwargs):
125-
def join(axis, *tensors):
126-
# tensors could also be tuples, and in this case they don't have a ndim
127-
tensors = [torch.tensor(tensor) for tensor in tensors]
125+
def pytorch_funcify_Join(op, node, **kwargs):
126+
axis = node.inputs[0]
128127

129-
return torch.cat(tensors, dim=axis)
128+
if isinstance(axis, Constant):
129+
axis = int(axis.data)
130130

131-
return join
131+
def join_constant_axis(_, *tensors):
132+
return torch.cat(tensors, dim=axis)
133+
134+
return join_constant_axis
135+
136+
else:
137+
138+
def join(axis, *tensors):
139+
return torch.cat(tensors, dim=axis)
140+
141+
return join
132142

133143

134144
@pytorch_funcify.register(Eye)
@@ -173,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
173183
@pytorch_funcify.register(OpFromGraph)
174184
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
175185
kwargs.pop("storage_map", None)
176-
177186
# Apply inner rewrites
178187
PYTORCH.optimizer(op.fgraph)
179188
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -190,7 +199,19 @@ def tensorfromscalar(x):
190199

191200
@pytorch_funcify.register(Split)
192201
def pytorch_funcify_Split(op, node, **kwargs):
193-
def inner_fn(x, dim, split_amounts):
194-
return x.split(split_amounts.tolist(), dim=dim.item())
202+
x, dim, split_sizes = node.inputs
203+
if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
204+
dim = int(dim.data)
205+
split_sizes = tuple(int(size) for size in split_sizes.data)
206+
207+
def split_constant_axis_and_sizes(x, *_):
208+
return x.split(split_sizes, dim=dim)
209+
210+
return split_constant_axis_and_sizes
211+
212+
else:
213+
214+
def inner_fn(x, dim, split_amounts):
215+
return x.split(split_amounts.tolist(), dim=dim.item())
195216

196-
return inner_fn
217+
return inner_fn

pytensor/link/pytorch/dispatch/shape.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,28 @@
11
import torch
22

3+
from pytensor.graph.basic import Constant
34
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
45
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
56

67

78
@pytorch_funcify.register(Reshape)
89
def pytorch_funcify_Reshape(op, node, **kwargs):
9-
def reshape(x, shape):
10-
return torch.reshape(x, tuple(shape))
10+
_, shape = node.inputs
1111

12-
return reshape
12+
if isinstance(shape, Constant):
13+
constant_shape = tuple(int(dim) for dim in shape.data)
14+
15+
def reshape_constant_shape(x, *_):
16+
return torch.reshape(x, constant_shape)
17+
18+
return reshape_constant_shape
19+
20+
else:
21+
22+
def reshape(x, shape):
23+
return torch.reshape(x, tuple(shape))
24+
25+
return reshape
1326

1427

1528
@pytorch_funcify.register(Shape)

pytensor/link/pytorch/dispatch/subtensor.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from pytensor.graph.basic import Constant
12
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
23
from pytensor.tensor.subtensor import (
34
AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
2324
@pytorch_funcify.register(Subtensor)
2425
def pytorch_funcify_Subtensor(op, node, **kwargs):
2526
idx_list = op.idx_list
27+
x, *idxs = node.inputs
2628

29+
if all(isinstance(idx, Constant) for idx in idxs):
30+
# Use constant indices to avoid graph break
31+
constant_indices = indices_from_subtensor(
32+
[int(idx.data) for idx in idxs], idx_list
33+
)
34+
check_negative_steps(constant_indices)
35+
36+
def constant_index_subtensor(x, *_):
37+
return x[constant_indices]
38+
39+
return constant_index_subtensor
40+
41+
# Fallback that will introduce a graph break
2742
def subtensor(x, *flattened_indices):
2843
indices = indices_from_subtensor(flattened_indices, idx_list)
2944
check_negative_steps(indices)

pytensor/link/pytorch/linker.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
3737
def jit_compile(self, fn):
3838
import torch
3939

40+
# flag that tend to help our graphs
41+
torch._dynamo.config.capture_dynamic_output_shape_ops = True
42+
4043
from pytensor.link.pytorch.dispatch import pytorch_typify
4144

4245
class wrapper:

0 commit comments

Comments (0)