Skip to content

Commit 46fdc58

Browse files
Ch0ronomatoIan Schweer
and
Ian Schweer
authored
Add torch implementation of IfElse (#974)
Co-authored-by: Ian Schweer <[email protected]>
1 parent 8a6e407 commit 46fdc58

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

pytensor/link/pytorch/dispatch/basic.py

+14
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.builders import OpFromGraph
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.graph.fg import FunctionGraph
12+
from pytensor.ifelse import IfElse
1213
from pytensor.link.utils import fgraph_to_python
1314
from pytensor.raise_op import CheckAndRaise
1415
from pytensor.tensor.basic import (
@@ -153,6 +154,19 @@ def makevector(*x):
153154
return makevector
154155

155156

157+
@pytorch_funcify.register(IfElse)
158+
def pytorch_funcify_IfElse(op, **kwargs):
159+
n_outs = op.n_outs
160+
161+
def ifelse(cond, *true_and_false, n_outs=n_outs):
162+
if cond:
163+
return true_and_false[:n_outs]
164+
else:
165+
return true_and_false[n_outs:]
166+
167+
return ifelse
168+
169+
156170
@pytorch_funcify.register(OpFromGraph)
157171
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
158172
kwargs.pop("storage_map", None)

tests/link/pytorch/test_basic.py

+18
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.graph.basic import Apply
1414
from pytensor.graph.fg import FunctionGraph
1515
from pytensor.graph.op import Op
16+
from pytensor.ifelse import ifelse
1617
from pytensor.raise_op import CheckAndRaise
1718
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1819
from pytensor.tensor.type import matrices, matrix, scalar, vector
@@ -304,6 +305,23 @@ def test_pytorch_MakeVector():
304305
compare_pytorch_and_py(x_fg, [])
305306

306307

308+
def test_pytorch_ifelse():
309+
p1_vals = np.r_[1, 2, 3]
310+
p2_vals = np.r_[-1, -2, -3]
311+
312+
a = scalar("a")
313+
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
314+
x_fg = FunctionGraph([a], x)
315+
316+
compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX))
317+
318+
a = scalar("a")
319+
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
320+
x_fg = FunctionGraph([a], x)
321+
322+
compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX))
323+
324+
307325
def test_pytorch_OpFromGraph():
308326
x, y, z = matrices("xyz")
309327
ofg_1 = OpFromGraph([x, y], [x + y])

0 commit comments

Comments
 (0)