diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 0a92bac106..135433a0ab 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -760,7 +760,31 @@ def switch(cond, ift, iff): """if cond then ift else iff""" -where = switch +def where(cond, ift=None, iff=None, **kwargs): + """ + where(condition, [ift, iff]) + Return elements chosen from `ift` or `iff` depending on `condition`. + + Note: When only condition is provided, this function is a shorthand for `as_tensor(condition).nonzero()`. + + Parameters + ---------- + condition : tensor_like, bool + Where True, yield `ift`, otherwise yield `iff`. + x, y : tensor_like + Values from which to choose. + + Returns + ------- + out : TensorVariable + A tensor with elements from `ift` where `condition` is True, and elements from `iff` elsewhere. + """ + if ift is not None and iff is not None: + return switch(cond, ift, iff, **kwargs) + elif ift is None and iff is None: + return as_tensor(cond).nonzero(**kwargs) + else: + raise ValueError("either both or neither of ift and iff should be given") @scalar_elemwise diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 08e969356a..ed8909944a 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -87,6 +87,7 @@ triu_indices, triu_indices_from, vertical_stack, + where, zeros_like, ) from pytensor.tensor.blockwise import Blockwise @@ -4608,3 +4609,20 @@ def core_np(x, y): vectorize_pt(x_test, y_test), vectorize_np(x_test, y_test), ) + + +def test_where(): + a = np.arange(10) + cond = a < 5 + ift = np.pi + iff = np.e + # Test for all 3 inputs + np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval()) + + # Test for only condition input + for np_output, pt_output in zip(np.where(cond), where(cond)): + np.testing.assert_allclose(np_output, pt_output.eval()) + + # Test for error + with pytest.raises(ValueError, match="either both"): + where(cond, ift)