diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..6cf4f29aab 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -19,6 +19,7 @@ Eye, Join, MakeVector, + Split, TensorFromScalar, ) @@ -185,3 +186,11 @@ def tensorfromscalar(x): return torch.as_tensor(x) return tensorfromscalar + + +@pytorch_funcify.register(Split) +def pytorch_funcify_Split(op, node, **kwargs): + def inner_fn(x, dim, split_amounts): + return x.split(split_amounts.tolist(), dim=dim.item()) + + return inner_fn diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 1416e58f55..2505ef655a 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -5,11 +5,17 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( Cast, + Invert, ScalarOp, ) from pytensor.scalar.math import Softplus +@pytorch_funcify.register(Invert) +def pytorch_funcify_invert(op, node, **kwargs): + return torch.bitwise_not + + @pytorch_funcify.register(ScalarOp) def pytorch_funcify_ScalarOp(op, node, **kwargs): """Return pytorch function that implements the same computation as the Scalar Op.