File tree Expand file tree Collapse file tree 2 files changed +15
-0
lines changed
pytensor/link/pytorch/dispatch Expand file tree Collapse file tree 2 files changed +15
-0
lines changed Original file line number Diff line number Diff line change 19
19
Eye ,
20
20
Join ,
21
21
MakeVector ,
22
+ Split ,
22
23
TensorFromScalar ,
23
24
)
24
25
@@ -185,3 +186,11 @@ def tensorfromscalar(x):
185
186
return torch .as_tensor (x )
186
187
187
188
return tensorfromscalar
189
+
190
+
191
+ @pytorch_funcify .register (Split )
192
+ def pytorch_funcify_Split (op , node , ** kwargs ):
193
+ def inner_fn (x , dim , split_amounts ):
194
+ return x .split (split_amounts .tolist (), dim = dim .item ())
195
+
196
+ return inner_fn
Original file line number Diff line number Diff line change 5
5
from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
6
6
from pytensor .scalar .basic import (
7
7
Cast ,
8
+ Invert ,
8
9
ScalarOp ,
9
10
)
10
11
from pytensor .scalar .loop import ScalarLoop
11
12
from pytensor .scalar .math import Softplus
12
13
13
14
15
+ @pytorch_funcify .register (Invert )
16
+ def pytorch_funcify_invert (op , node , ** kwargs ):
17
+ return torch .bitwise_not
18
+
19
+
14
20
@pytorch_funcify .register (ScalarOp )
15
21
def pytorch_funcify_ScalarOp (op , node , ** kwargs ):
16
22
"""Return pytorch function that implements the same computation as the Scalar Op.
You can’t perform that action at this time.
0 commit comments