Skip to content

Commit 3c4f73d

Browse files
author
Ian Schweer
committed
Add IfElse
1 parent 426931b commit 3c4f73d

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pytensor.compile.ops import DeepCopyOp
77
from pytensor.graph.fg import FunctionGraph
8+
from pytensor.ifelse import IfElse
89
from pytensor.link.utils import fgraph_to_python
910
from pytensor.raise_op import CheckAndRaise
1011
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
@@ -124,6 +125,7 @@ def eye(N, M, k):
124125
return eye
125126

126127

128+
127129
@pytorch_funcify.register(MakeVector)
128130
def pytorch_funcify_MakeVector(op, **kwargs):
129131
torch_dtype = getattr(torch, op.dtype)
@@ -132,3 +134,14 @@ def makevector(*x):
132134
return torch.tensor(x, dtype=torch_dtype)
133135

134136
return makevector
137+
138+
139+
@pytorch_funcify.register(IfElse)
140+
def pytorch_funcify_IfElse(op, **kwargs):
141+
n_outs = op.n_outs
142+
assert n_outs == 1
143+
144+
def ifelse(cond, *args, n_outs=n_outs):
145+
return torch.where(cond, *args)
146+
147+
return ifelse

tests/link/pytorch/test_basic.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.fg import FunctionGraph
14-
from pytensor.graph.op import Op
14+
from pytensor.graph.op import Op, get_test_value
15+
from pytensor.ifelse import ifelse
1516
from pytensor.raise_op import CheckAndRaise
1617
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1718
from pytensor.tensor.type import matrix, scalar, vector
@@ -301,3 +302,20 @@ def test_pytorch_MakeVector():
301302
x_fg = FunctionGraph([], [x])
302303

303304
compare_pytorch_and_py(x_fg, [])
305+
306+
307+
def test_pytorch_ifelse():
308+
true_vals = np.r_[1, 2, 3]
309+
false_vals = np.r_[-1, -2, -3]
310+
311+
x = ifelse(np.array(True), true_vals, false_vals)
312+
x_fg = FunctionGraph([], [x])
313+
314+
compare_pytorch_and_py(x_fg, [])
315+
316+
a = scalar("a")
317+
a.tag.test_value = np.array(0.2, dtype=config.floatX)
318+
x = ifelse(a < 0.5, true_vals, false_vals)
319+
x_fg = FunctionGraph([a], [x]) # I.e. False
320+
321+
compare_pytorch_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])

0 commit comments

Comments
 (0)