Skip to content

Commit 7159215

Browse files
authored
Update tensor.where to allow for case with only condition (#844)
1 parent d3bd1f1 commit 7159215

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

pytensor/tensor/basic.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,31 @@ def switch(cond, ift, iff):
760760
"""if cond then ift else iff"""
761761

762762

763-
where = switch
763+
def where(cond, ift=None, iff=None, **kwargs):
764+
"""
765+
where(condition, [ift, iff])
766+
Return elements chosen from `ift` or `iff` depending on `condition`.
767+
768+
Note: When only condition is provided, this function is a shorthand for `as_tensor(condition).nonzero()`.
769+
770+
Parameters
771+
----------
772+
condition : tensor_like, bool
773+
Where True, yield `ift`, otherwise yield `iff`.
774+
x, y : tensor_like
775+
Values from which to choose.
776+
777+
Returns
778+
-------
779+
out : TensorVariable
780+
A tensor with elements from `ift` where `condition` is True, and elements from `iff` elsewhere.
781+
"""
782+
if ift is not None and iff is not None:
783+
return switch(cond, ift, iff, **kwargs)
784+
elif ift is None and iff is None:
785+
return as_tensor(cond).nonzero(**kwargs)
786+
else:
787+
raise ValueError("either both or neither of ift and iff should be given")
764788

765789

766790
@scalar_elemwise

tests/tensor/test_basic.py

+18
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
triu_indices,
8888
triu_indices_from,
8989
vertical_stack,
90+
where,
9091
zeros_like,
9192
)
9293
from pytensor.tensor.blockwise import Blockwise
@@ -4608,3 +4609,20 @@ def core_np(x, y):
46084609
vectorize_pt(x_test, y_test),
46094610
vectorize_np(x_test, y_test),
46104611
)
4612+
4613+
4614+
def test_where():
4615+
a = np.arange(10)
4616+
cond = a < 5
4617+
ift = np.pi
4618+
iff = np.e
4619+
# Test for all 3 inputs
4620+
np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval())
4621+
4622+
# Test for only condition input
4623+
for np_output, pt_output in zip(np.where(cond), where(cond)):
4624+
np.testing.assert_allclose(np_output, pt_output.eval())
4625+
4626+
# Test for error
4627+
with pytest.raises(ValueError, match="either both"):
4628+
where(cond, ift)

0 commit comments

Comments
 (0)