
Commit c477732

Ch0ronomato and Ian Schweer authored
Torch dispatch for scipy-like functions and Softplus (#1066)
* Allow for scipy module resolution
* Add softplus
* Add tests
* Allow scipy scalar handling
* Check for scipy in elemwise

---------

Co-authored-by: Ian Schweer <[email protected]>
1 parent ae66e82 commit c477732

File tree

3 files changed: +47 -4 lines changed

pytensor/link/pytorch/dispatch/elemwise.py (+17 -1)
@@ -1,3 +1,5 @@
+import importlib
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -11,12 +13,26 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
-    if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
+    def check_special_scipy(func_name):
+        if "scipy." not in func_name:
+            return False
+        loc = func_name.split(".")[1:]
+        try:
+            mod = importlib.import_module(".".join(loc[:-1]), "torch")
+            return getattr(mod, loc[-1], False)
+        except ImportError:
+            return False
+
+    if hasattr(scalar_op, "nfunc_spec") and (
+        hasattr(torch, scalar_op.nfunc_spec[0])
+        or check_special_scipy(scalar_op.nfunc_spec[0])
+    ):
         # torch can handle this scalar
         # broadcast, we'll let it.
         def elemwise_fn(*inputs):
             Elemwise._check_runtime_broadcast(node, inputs)
             return base_fn(*inputs)
+
     else:
 
         def elemwise_fn(*inputs):
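For context, check_special_scipy answers one question: does torch expose a counterpart for a "scipy."-prefixed nfunc_spec name? A minimal standalone sketch of that lookup (the helper name below is illustrative, not part of the commit), assuming a name such as "scipy.special.expit":

    import importlib

    def resolve_torch_equivalent(func_name):
        # Illustrative helper: map e.g. "scipy.special.expit" onto
        # torch.special.expit, or return None if torch has no counterpart.
        if not func_name.startswith("scipy."):
            return None
        parts = func_name.split(".")[1:]  # ["special", "expit"]
        try:
            # Import the matching torch submodule, e.g. "torch.special".
            mod = importlib.import_module(".".join(["torch", *parts[:-1]]))
        except ImportError:
            return None
        return getattr(mod, parts[-1], None)

    print(resolve_torch_equivalent("scipy.special.expit"))  # torch.special.expit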

pytensor/link/pytorch/dispatch/scalar.py (+15 -2)
@@ -1,10 +1,13 @@
+import importlib
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
     ScalarOp,
 )
+from pytensor.scalar.math import Softplus
 
 
 @pytorch_funcify.register(ScalarOp)
@@ -19,9 +22,14 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs):
     if nfunc_spec is None:
         raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
 
-    func_name = nfunc_spec[0]
+    func_name = nfunc_spec[0].replace("scipy.", "")
 
-    pytorch_func = getattr(torch, func_name)
+    if "." in func_name:
+        loc = func_name.split(".")
+        mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
+        pytorch_func = getattr(mod, loc[-1])
+    else:
+        pytorch_func = getattr(torch, func_name)
 
     if len(node.inputs) > op.nfunc_spec[1]:
         # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
@@ -49,3 +57,8 @@ def cast(x):
         return x.to(dtype=dtype)
 
     return cast
+
+
+@pytorch_funcify.register(Softplus)
+def pytorch_funcify_Softplus(op, node, **kwargs):
+    return torch.nn.Softplus()
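A quick sanity check (not from the commit) of why returning torch.nn.Softplus() works as the dispatch result: the module instance is a plain callable computing log(1 + exp(x)) elementwise, so it can stand in directly for the scalar op:

    import torch

    x = torch.tensor([-2.0, 0.0, 2.0])
    softplus = torch.nn.Softplus()  # callable module, default beta=1
    assert torch.allclose(softplus(x), torch.log1p(torch.exp(x)))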

tests/link/pytorch/test_basic.py (+15 -1)
@@ -17,7 +17,7 @@
 from pytensor.ifelse import ifelse
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor import alloc, arange, as_tensor, empty, eye
+from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
 from pytensor.tensor.type import matrices, matrix, scalar, vector
 
 
@@ -374,3 +374,17 @@ def inner_fn(x):
     f = function([x], out, mode="PYTORCH")
     f(torch.ones(3))
     assert "inner_fn" not in dir(m), "function call reference leaked"
+
+
+def test_pytorch_scipy():
+    x = vector("a", shape=(3,))
+    out = expit(x)
+    f = FunctionGraph([x], [out])
+    compare_pytorch_and_py(f, [np.random.rand(3)])
+
+
+def test_pytorch_softplus():
+    x = vector("a", shape=(3,))
+    out = softplus(x)
+    f = FunctionGraph([x], [out])
+    compare_pytorch_and_py(f, [np.random.rand(3)])
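Outside the test suite, the same path can be exercised by compiling a graph with the PyTorch backend directly. A small usage sketch (assuming float64 inputs and the "PYTORCH" mode name the existing tests already use):

    import numpy as np
    import pytensor
    from pytensor.tensor import expit, softplus
    from pytensor.tensor.type import vector

    x = vector("x", shape=(3,))
    # Both ops now lower to torch functions via the scipy-aware dispatch.
    f = pytensor.function([x], [expit(x), softplus(x)], mode="PYTORCH")
    print(f(np.random.rand(3)))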
