Skip to content

Commit 7ce2a5d

Browse files
committed
Split and inverse
1 parent 231a977 commit 7ce2a5d

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Eye,
2020
Join,
2121
MakeVector,
22+
Split,
2223
TensorFromScalar,
2324
)
2425

@@ -185,3 +186,11 @@ def tensorfromscalar(x):
185186
return torch.as_tensor(x)
186187

187188
return tensorfromscalar
189+
190+
191+
@pytorch_funcify.register(Split)
192+
def pytorch_funcify_Split(op, node, **kwargs):
193+
def inner_fn(x, dim, split_amounts):
194+
return x.split(split_amounts.tolist(), dim=dim.item())
195+
196+
return inner_fn

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
66
from pytensor.scalar.basic import (
77
Cast,
8+
Invert,
89
ScalarOp,
910
)
1011
from pytensor.scalar.loop import ScalarLoop
1112
from pytensor.scalar.math import Softplus
1213

1314

15+
@pytorch_funcify.register(Invert)
16+
def pytorch_funcify_invert(op, node, **kwargs):
17+
return torch.bitwise_not
18+
19+
1420
@pytorch_funcify.register(ScalarOp)
1521
def pytorch_funcify_ScalarOp(op, node, **kwargs):
1622
"""Return pytorch function that implements the same computation as the Scalar Op.

0 commit comments

Comments
 (0)