8
8
from pytensor .compile import PYTORCH
9
9
from pytensor .compile .builders import OpFromGraph
10
10
from pytensor .compile .ops import DeepCopyOp
11
+ from pytensor .graph .basic import Constant
11
12
from pytensor .graph .fg import FunctionGraph
12
13
from pytensor .ifelse import IfElse
13
14
from pytensor .link .utils import fgraph_to_python
19
20
Eye ,
20
21
Join ,
21
22
MakeVector ,
23
+ Split ,
22
24
TensorFromScalar ,
23
25
)
24
26
@@ -120,14 +122,23 @@ def arange(start, stop, step):
120
122
121
123
122
124
@pytorch_funcify .register (Join )
123
- def pytorch_funcify_Join (op , ** kwargs ):
124
- def join (axis , * tensors ):
125
- # tensors could also be tuples, and in this case they don't have a ndim
126
- tensors = [torch .tensor (tensor ) for tensor in tensors ]
125
+ def pytorch_funcify_Join (op , node , ** kwargs ):
126
+ axis = node .inputs [0 ]
127
127
128
- return torch .cat (tensors , dim = axis )
128
+ if isinstance (axis , Constant ):
129
+ axis = int (axis .data )
129
130
130
- return join
131
+ def join_constant_axis (_ , * tensors ):
132
+ return torch .cat (tensors , dim = axis )
133
+
134
+ return join_constant_axis
135
+
136
+ else :
137
+
138
+ def join (axis , * tensors ):
139
+ return torch .cat (tensors , dim = axis )
140
+
141
+ return join
131
142
132
143
133
144
@pytorch_funcify .register (Eye )
@@ -172,7 +183,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs):
172
183
@pytorch_funcify .register (OpFromGraph )
173
184
def pytorch_funcify_OpFromGraph (op , node , ** kwargs ):
174
185
kwargs .pop ("storage_map" , None )
175
-
176
186
# Apply inner rewrites
177
187
PYTORCH .optimizer (op .fgraph )
178
188
fgraph_fn = pytorch_funcify (op .fgraph , ** kwargs , squeeze_output = True )
@@ -185,3 +195,23 @@ def tensorfromscalar(x):
185
195
return torch .as_tensor (x )
186
196
187
197
return tensorfromscalar
198
+
199
+
200
+ @pytorch_funcify .register (Split )
201
+ def pytorch_funcify_Split (op , node , ** kwargs ):
202
+ x , dim , split_sizes = node .inputs
203
+ if isinstance (dim , Constant ) and isinstance (split_sizes , Constant ):
204
+ dim = int (dim .data )
205
+ split_sizes = tuple (int (size ) for size in split_sizes .data )
206
+
207
+ def split_constant_axis_and_sizes (x , * _ ):
208
+ return x .split (split_sizes , dim = dim )
209
+
210
+ return split_constant_axis_and_sizes
211
+
212
+ else :
213
+
214
+ def inner_fn (x , dim , split_amounts ):
215
+ return x .split (split_amounts .tolist (), dim = dim .item ())
216
+
217
+ return inner_fn
0 commit comments