
Commit 4fa9bb8

PyTorch inline constants in dispatch to avoid graph breaks (#1118)
* Split and inverse
* PyTorch inline constants in dispatch to avoid graph breaks
1 parent 17748b7 commit 4fa9bb8

File tree

6 files changed, +127 −10 lines changed

pytensor/link/pytorch/dispatch/basic.py (+37 −7)

@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
     Eye,
     Join,
     MakeVector,
+    Split,
     TensorFromScalar,
 )
 
@@ -120,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
 
-        return torch.cat(tensors, dim=axis)
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
 
-    return join
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
@@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
 @pytorch_funcify.register(OpFromGraph)
 def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
-
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -185,3 +195,23 @@ def tensorfromscalar(x):
         return torch.as_tensor(x)
 
     return tensorfromscalar
+
+
+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
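
The Join and Split changes above share one idea: when the axis or the split sizes are Constants in the PyTensor graph, their values are baked into the returned closure as plain Python ints, so torch.compile never has to read them out of a runtime tensor. A minimal sketch of the difference (not part of the commit; assumes torch is installed, and that Tensor.item() under Dynamo typically forces a graph break unless scalar capture is enabled):

import torch


def cat_dynamic_axis(axis, *tensors):
    # axis arrives as a 0-d tensor at run time; reading it with .item()
    # is a data-dependent step that Dynamo usually cannot trace through
    return torch.cat(tensors, dim=int(axis.item()))


def make_cat_constant_axis(axis: int):
    # axis is a plain Python int captured in the closure, mirroring the
    # join_constant_axis path above
    def cat_constant_axis(_, *tensors):
        return torch.cat(tensors, dim=axis)

    return cat_constant_axis


x, y = torch.ones(2, 3), torch.zeros(2, 3)
axis = torch.tensor(0)

dynamic_fn = torch.compile(cat_dynamic_axis)                             # runs, but with a graph break
constant_fn = torch.compile(make_cat_constant_axis(0), fullgraph=True)   # captured as a single graph

print(dynamic_fn(axis, x, y).shape)   # torch.Size([4, 3])
print(constant_fn(axis, x, y).shape)  # torch.Size([4, 3])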

pytensor/link/pytorch/dispatch/scalar.py (+6)

@@ -5,12 +5,18 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Invert,
     ScalarOp,
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not
+
+
 @pytorch_funcify.register(ScalarOp)
 def pytorch_funcify_ScalarOp(op, node, **kwargs):
     """Return pytorch function that implements the same computation as the Scalar Op.

pytensor/link/pytorch/dispatch/shape.py (+16 −3)

@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
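
Reshape follows the same constant-inlining pattern: when the shape input is a Constant, it is converted once to a tuple of Python ints, so the compiled callable never iterates over a runtime tensor. Roughly, the two paths correspond to the following (illustrative sketch only):

import torch

x = torch.arange(12)

# Constant path: the target shape is known when the thunk is built
constant_shape = (3, 4)


def reshape_constant(x, *_):
    return torch.reshape(x, constant_shape)


# Fallback path: the shape only arrives at call time as a tensor, and
# tuple(shape) yields 0-d tensors that Dynamo must materialize
def reshape_dynamic(x, shape):
    return torch.reshape(x, tuple(shape))


print(reshape_constant(x).shape)                       # torch.Size([3, 4])
print(reshape_dynamic(x, torch.tensor([3, 4])).shape)  # torch.Size([3, 4])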

pytensor/link/pytorch/dispatch/subtensor.py (+15)

@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
 
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
     def subtensor(x, *flattened_indices):
         indices = indices_from_subtensor(flattened_indices, idx_list)
         check_negative_steps(indices)
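
The Subtensor rule applies the same idea to basic indexing: when every index input is a Constant, the full index tuple of Python ints and slices is assembled once by indices_from_subtensor and validated up front, and the runtime function reduces to a plain x[constant_indices]. An illustrative comparison of the two situations (hypothetical snippet, not from the commit):

import torch

x = torch.arange(20).reshape(4, 5)

# Constant path: the index tuple is built before compilation
constant_indices = (slice(1, 3), 2)


def constant_index_subtensor(x, *_):
    return x[constant_indices]


# Fallback path: the index values arrive as 0-d tensors and must be read
# out before the slice can be constructed, which can break the graph
def dynamic_subtensor(x, start, stop, col):
    return x[int(start) : int(stop), int(col)]


print(constant_index_subtensor(x))
print(dynamic_subtensor(x, torch.tensor(1), torch.tensor(3), torch.tensor(2)))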

pytensor/link/pytorch/linker.py (+3)

@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch
 
+        # flag that tend to help our graphs
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
         from pytensor.link.pytorch.dispatch import pytorch_typify
 
         class wrapper:
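
capture_dynamic_output_shape_ops is an existing torch._dynamo config flag; enabling it lets Dynamo keep operations whose output shape depends on tensor data (the kind the fallback Split and Subtensor paths produce) inside the captured graph instead of breaking. A rough illustration of its effect (behaviour differs across torch versions, so treat this as a sketch):

import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True


@torch.compile
def positive_indices(x):
    # nonzero has a data-dependent output shape; without the flag Dynamo
    # typically has to fall back to eager execution (a graph break) here
    return torch.nonzero(x > 0)


print(positive_indices(torch.tensor([-1.0, 2.0, 0.0, 3.0])))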

tests/link/pytorch/test_basic.py (+50)

@@ -471,3 +471,53 @@ def test_ScalarLoop_Elemwise_multi_carries():
     compare_pytorch_and_py(
         f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
     )
+
+
+rng = np.random.default_rng(42849)
+
+
+@pytest.mark.parametrize(
+    "n_splits, axis, values, sizes",
+    [
+        (
+            0,
+            0,
+            rng.normal(size=20).astype(config.floatX),
+            [],
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=5).astype(config.floatX),
+            rng.multinomial(5, np.ones(5) / 5),
+        ),
+        (
+            5,
+            0,
+            rng.normal(size=10).astype(config.floatX),
+            rng.multinomial(10, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -1,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(7, np.ones(5) / 5),
+        ),
+        (
+            5,
+            -2,
+            rng.normal(size=(11, 7)).astype(config.floatX),
+            rng.multinomial(11, np.ones(5) / 5),
+        ),
+    ],
+)
+def test_Split(n_splits, axis, values, sizes):
+    i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
+    s = pt.vector("s", dtype="int64")
+    g = pt.split(i, s, n_splits, axis=axis)
+    assert len(g) == n_splits
+    if n_splits == 0:
+        return
+    g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
+
+    compare_pytorch_and_py(g_fg, [values, sizes])
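
For completeness, the same kind of Split graph can be exercised end to end through the PyTorch backend by selecting the PYTORCH compilation mode. A hypothetical usage sketch (not part of the test suite; assumes pytensor and torch are installed):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
s = pt.vector("s", dtype="int64")
parts = pt.split(x, s, 2, axis=0)

# mode="PYTORCH" routes compilation through the linker patched above
fn = pytensor.function([x, s], parts, mode="PYTORCH")
a, b = fn(np.arange(5.0), np.array([2, 3], dtype="int64"))
print(a, b)  # [0. 1.] [2. 3. 4.]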
