From 647edf39b97791b3a53b367f8404a4f278e080fb Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sat, 22 Jun 2024 15:58:46 +0530 Subject: [PATCH 1/8] updated where to allow for case with only condition --- pytensor/tensor/basic.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 0a92bac106..3ba7f590b4 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -760,7 +760,12 @@ def switch(cond, ift, iff): """if cond then ift else iff""" -where = switch +# where = switch +def where(cond, ift, iff): + if ift is not None and iff is not None: + return switch(cond, ift, iff) + else: + pass @scalar_elemwise From 1c66378e295491ab0486ed8ab02c55b66369f17e Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sat, 22 Jun 2024 16:02:38 +0530 Subject: [PATCH 2/8] add code for case with only condition --- pytensor/tensor/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 3ba7f590b4..245049c445 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -765,7 +765,7 @@ def where(cond, ift, iff): if ift is not None and iff is not None: return switch(cond, ift, iff) else: - pass + return cond.nonzero() @scalar_elemwise From 4174fa3d265873bd87a3ed5f64b6bc5d2506d97c Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sat, 22 Jun 2024 22:47:33 +0530 Subject: [PATCH 3/8] updated cases and added test for only condition --- pytensor/tensor/basic.py | 12 ++++++++---- tests/tensor/test_basic.py | 9 +++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 245049c445..0bc7588f5e 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -760,12 +760,16 @@ def switch(cond, ift, iff): """if cond then ift else iff""" -# where = switch -def where(cond, ift, iff): - if ift is not None and iff is not None: +def where(cond, ift=None, iff=None): + # Raise an error if only one of ift or iff is passed + if (ift is None and iff is not None) or (ift is not None and iff is None): + raise Exception("Either both or none of the parameters should be passed") + # Normal switch incase both arguements are passed + elif ift is not None and iff is not None: return switch(cond, ift, iff) + # Add case when only condition is passed else: - return cond.nonzero() + return as_tensor(cond).nonzero() @scalar_elemwise diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 08e969356a..8ec425eecd 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -87,6 +87,7 @@ triu_indices, triu_indices_from, vertical_stack, + where, zeros_like, ) from pytensor.tensor.blockwise import Blockwise @@ -4608,3 +4609,11 @@ def core_np(x, y): vectorize_pt(x_test, y_test), vectorize_np(x_test, y_test), ) + + +def test_where_for_only_condition(): + a = np.array([1, 2, 3, 4, 5]) + cond = a <= 3 + pt_result = where(cond)[0].eval() + np_result = np.where(cond)[0] + np.testing.assert_allclose(pt_result, np_result) From b23c4445f625ed81b96551bdec669b79a8bb1277 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 23 Jun 2024 11:57:04 +0530 Subject: [PATCH 4/8] cleaned code and working tests --- pytensor/tensor/basic.py | 10 +++++----- tests/tensor/test_basic.py | 19 +++++++++++++++---- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 0bc7588f5e..9aeb7b1c39 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -761,15 +761,15 @@ def switch(cond, ift, iff): def where(cond, ift=None, iff=None): - # Raise an error if only one of ift or iff is passed - if (ift is None and iff is not None) or (ift is not None and iff is None): - raise Exception("Either both or none of the parameters should be passed") # Normal switch incase both arguements are passed - elif ift is not None and iff is not None: + if ift is not None and iff is not None: return switch(cond, ift, iff) # Add case when only condition is passed - else: + elif ift is None and iff is None: return as_tensor(cond).nonzero() + # Raise an error if only one arguement is passed + else: + raise Exception("Either both or none of the parameters should be passed") @scalar_elemwise diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 8ec425eecd..68de9dfd72 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -4611,9 +4611,20 @@ def core_np(x, y): ) -def test_where_for_only_condition(): +@pytest.mark.parametrize( + "ift, iff", + [(None, None), (7, 10), (7, None)], + ids=["both none", "both valid", "one none"], +) +def test_where_for_only_condition(ift, iff): a = np.array([1, 2, 3, 4, 5]) - cond = a <= 3 - pt_result = where(cond)[0].eval() - np_result = np.where(cond)[0] + cond = a >= 3 + if ift is None and iff is None: + pt_where = where(cond) + np_result = np.where(cond) + else: + pt_where = where(cond, ift, iff) + np_result = np.where(cond, ift, iff) + f_test = function([], pt_where) + pt_result = f_test() np.testing.assert_allclose(pt_result, np_result) From 55bad54b750f655827f3ab3b92cee45643448593 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 23 Jun 2024 22:21:40 +0530 Subject: [PATCH 5/8] updated test for checking exception --- tests/tensor/test_basic.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 68de9dfd72..19aa61982e 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -4618,13 +4618,14 @@ def core_np(x, y): ) def test_where_for_only_condition(ift, iff): a = np.array([1, 2, 3, 4, 5]) - cond = a >= 3 - if ift is None and iff is None: - pt_where = where(cond) + cond = a < 3 + if ift is not None and iff is not None: + pt_result = function([], where(cond, ift, iff))() + np_result = np.where(cond, ift, iff) + np.testing.assert_allclose(pt_result, np_result) + elif ift is None and iff is None: + pt_result = function([], where(cond))() np_result = np.where(cond) + np.testing.assert_allclose(pt_result, np_result) else: - pt_where = where(cond, ift, iff) - np_result = np.where(cond, ift, iff) - f_test = function([], pt_where) - pt_result = f_test() - np.testing.assert_allclose(pt_result, np_result) + pytest.raises(Exception, where, cond, ift) From 14a9ca70cf6a9c8d32ca3b095c57b979d31917e0 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 24 Jun 2024 12:42:02 +0530 Subject: [PATCH 6/8] added docstring and test with xfail --- pytensor/tensor/basic.py | 21 ++++++++++++++++++--- tests/tensor/test_basic.py | 15 ++++----------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 9aeb7b1c39..5a52f18fea 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -761,13 +761,28 @@ def switch(cond, ift, iff): def where(cond, ift=None, iff=None): - # Normal switch incase both arguements are passed + """ + where(condition, [ift, iff]) + Return elements chosen from `ift` or `iff` depending on `condition`. + + Note: When only condition is provided, this function is a shorthand for `as_tensor(condition).nonzero()`. + + Parameters + ---------- + condition : tensor_like, bool + Where True, yield `ift`, otherwise yield `iff`. + x, y : tensor_like + Values from which to choose. + + Returns + ------- + out : TensorVariable + A tensor with elements from `ift` where `condition` is True, and elements from `iff` elsewhere. + """ if ift is not None and iff is not None: return switch(cond, ift, iff) - # Add case when only condition is passed elif ift is None and iff is None: return as_tensor(cond).nonzero() - # Raise an error if only one arguement is passed else: raise Exception("Either both or none of the parameters should be passed") diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 19aa61982e..5c5d0416a7 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -4613,19 +4613,12 @@ def core_np(x, y): @pytest.mark.parametrize( "ift, iff", - [(None, None), (7, 10), (7, None)], + [(None, None), (7, 10), pytest.param(7, None, marks=[pytest.mark.xfail])], ids=["both none", "both valid", "one none"], ) def test_where_for_only_condition(ift, iff): a = np.array([1, 2, 3, 4, 5]) cond = a < 3 - if ift is not None and iff is not None: - pt_result = function([], where(cond, ift, iff))() - np_result = np.where(cond, ift, iff) - np.testing.assert_allclose(pt_result, np_result) - elif ift is None and iff is None: - pt_result = function([], where(cond))() - np_result = np.where(cond) - np.testing.assert_allclose(pt_result, np_result) - else: - pytest.raises(Exception, where, cond, ift) + np_result = np.where(*[x for x in [cond, ift, iff] if x is not None]) + pt_result = function([], where(cond, ift, iff))() + np.testing.assert_allclose(np_result, pt_result) From 3e601c6daa8e1bd32bde64a91e563c59305b0a52 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 24 Jun 2024 13:04:57 +0530 Subject: [PATCH 7/8] added kwargs and removed parametrised tests --- pytensor/tensor/basic.py | 4 ++-- tests/tensor/test_basic.py | 26 +++++++++++++++----------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5a52f18fea..a2f816bc68 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -760,7 +760,7 @@ def switch(cond, ift, iff): """if cond then ift else iff""" -def where(cond, ift=None, iff=None): +def where(cond, ift=None, iff=None, **kwargs): """ where(condition, [ift, iff]) Return elements chosen from `ift` or `iff` depending on `condition`. @@ -784,7 +784,7 @@ def where(cond, ift=None, iff=None): elif ift is None and iff is None: return as_tensor(cond).nonzero() else: - raise Exception("Either both or none of the parameters should be passed") + raise ValueError("either both or neither of ift and iff should be given") @scalar_elemwise diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 5c5d0416a7..ed8909944a 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -4611,14 +4611,18 @@ def core_np(x, y): ) -@pytest.mark.parametrize( - "ift, iff", - [(None, None), (7, 10), pytest.param(7, None, marks=[pytest.mark.xfail])], - ids=["both none", "both valid", "one none"], -) -def test_where_for_only_condition(ift, iff): - a = np.array([1, 2, 3, 4, 5]) - cond = a < 3 - np_result = np.where(*[x for x in [cond, ift, iff] if x is not None]) - pt_result = function([], where(cond, ift, iff))() - np.testing.assert_allclose(np_result, pt_result) +def test_where(): + a = np.arange(10) + cond = a < 5 + ift = np.pi + iff = np.e + # Test for all 3 inputs + np.testing.assert_allclose(np.where(cond, ift, iff), where(cond, ift, iff).eval()) + + # Test for only condition input + for np_output, pt_output in zip(np.where(cond), where(cond)): + np.testing.assert_allclose(np_output, pt_output.eval()) + + # Test for error + with pytest.raises(ValueError, match="either both"): + where(cond, ift) From 49153cd47c02ae6d27a4978e8d91e946794f560e Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 24 Jun 2024 13:08:47 +0530 Subject: [PATCH 8/8] propagate kwargs --- pytensor/tensor/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index a2f816bc68..135433a0ab 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -780,9 +780,9 @@ def where(cond, ift=None, iff=None, **kwargs): A tensor with elements from `ift` where `condition` is True, and elements from `iff` elsewhere. """ if ift is not None and iff is not None: - return switch(cond, ift, iff) + return switch(cond, ift, iff, **kwargs) elif ift is None and iff is None: - return as_tensor(cond).nonzero() + return as_tensor(cond).nonzero(**kwargs) else: raise ValueError("either both or neither of ift and iff should be given")