Skip to content

Commit 523a561

Browse files
author
Ian Schweer
committed
Add IfElse
1 parent 308bc01 commit 523a561

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pytensor.compile.ops import DeepCopyOp
66
from pytensor.graph.fg import FunctionGraph
7+
from pytensor.ifelse import IfElse
78
from pytensor.link.utils import fgraph_to_python
89
from pytensor.raise_op import CheckAndRaise
910
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join
@@ -116,3 +117,14 @@ def eye(N, M, k):
116117
return zeros
117118

118119
return eye
120+
121+
122+
@pytorch_funcify.register(IfElse)
123+
def pytorch_funcify_IfElse(op, **kwargs):
124+
n_outs = op.n_outs
125+
assert n_outs == 1
126+
127+
def ifelse(cond, *args, n_outs=n_outs):
128+
return torch.where(cond, *args)
129+
130+
return ifelse

tests/link/pytorch/test_basic.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.fg import FunctionGraph
14-
from pytensor.graph.op import Op
14+
from pytensor.graph.op import Op, get_test_value
15+
from pytensor.ifelse import ifelse
1516
from pytensor.raise_op import CheckAndRaise
1617
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1718
from pytensor.tensor.type import matrix, scalar, vector
@@ -294,3 +295,20 @@ def test_eye(dtype):
294295
for _M in range(1, 6):
295296
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
296297
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))
298+
299+
300+
def test_pytorch_ifelse():
301+
true_vals = np.r_[1, 2, 3]
302+
false_vals = np.r_[-1, -2, -3]
303+
304+
x = ifelse(np.array(True), true_vals, false_vals)
305+
x_fg = FunctionGraph([], [x])
306+
307+
compare_pytorch_and_py(x_fg, [])
308+
309+
a = scalar("a")
310+
a.tag.test_value = np.array(0.2, dtype=config.floatX)
311+
x = ifelse(a < 0.5, true_vals, false_vals)
312+
x_fg = FunctionGraph([a], [x]) # I.e. False
313+
314+
compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])

0 commit comments

Comments
 (0)