From 5037932f78ebdee297c26eaf65585e7290b4ef32 Mon Sep 17 00:00:00 2001 From: gfyoung Date: Wed, 15 Feb 2017 12:02:03 -0500 Subject: [PATCH] BUG: Accept generic array-like in .where --- doc/source/whatsnew/v0.20.0.txt | 1 + pandas/core/generic.py | 37 +++++---- pandas/indexes/base.py | 2 +- pandas/tests/frame/test_indexing.py | 89 ++++++++++++++++++++++ pandas/tests/indexes/common.py | 12 +++ pandas/tests/indexes/period/test_period.py | 11 ++- pandas/tests/indexes/test_category.py | 12 ++- pandas/tests/indexes/test_multi.py | 9 +++ pandas/tests/series/test_indexing.py | 54 +++++++++++++ 9 files changed, 208 insertions(+), 19 deletions(-) diff --git a/doc/source/whatsnew/v0.20.0.txt b/doc/source/whatsnew/v0.20.0.txt index fa24c973a7549..2d70d79509e77 100644 --- a/doc/source/whatsnew/v0.20.0.txt +++ b/doc/source/whatsnew/v0.20.0.txt @@ -548,6 +548,7 @@ Bug Fixes +- Bug in ``Series.where()`` and ``DataFrame.where()`` where array-like conditionals were being rejected (:issue:`15414`) - Bug in ``Series`` construction with a datetimetz (:issue:`14928`) - Bug in compat for passing long integers to ``Timestamp.replace`` (:issue:`15030`) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 76fbb9884753d..10914e4baf61b 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4732,17 +4732,28 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None, cond, _ = cond.align(self, join='right', broadcast_axis=1) else: if not hasattr(cond, 'shape'): - raise ValueError('where requires an ndarray like object for ' - 'its condition') + cond = np.asanyarray(cond) if cond.shape != self.shape: raise ValueError('Array conditional must be same shape as ' 'self') cond = self._constructor(cond, **self._construct_axes_dict()) - if inplace: - cond = -(cond.fillna(True).astype(bool)) + fill_value = True if inplace else False + cond = cond.fillna(fill_value) + + msg = "Boolean array expected for the condition, not {dtype}" + + if not isinstance(cond, pd.DataFrame): + # This is a single-dimensional object. + if not is_bool_dtype(cond): + raise ValueError(msg.format(dtype=cond.dtype)) else: - cond = cond.fillna(False).astype(bool) + for dt in cond.dtypes: + if not is_bool_dtype(dt): + raise ValueError(msg.format(dtype=dt)) + + cond = cond.astype(bool, copy=False) + cond = -cond if inplace else cond # try to align try_quick = True @@ -4891,26 +4902,20 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None, Parameters ---------- - cond : boolean %(klass)s, array or callable + cond : boolean %(klass)s, array-like, or callable If cond is callable, it is computed on the %(klass)s and - should return boolean %(klass)s or array. - The callable must not change input %(klass)s - (though pandas doesn't check it). + should return boolean %(klass)s or array. The callable must + not change input %(klass)s (though pandas doesn't check it). .. versionadded:: 0.18.1 - A callable can be used as cond. - other : scalar, %(klass)s, or callable If other is callable, it is computed on the %(klass)s and - should return scalar or %(klass)s. - The callable must not change input %(klass)s - (though pandas doesn't check it). + should return scalar or %(klass)s. The callable must not + change input %(klass)s (though pandas doesn't check it). .. versionadded:: 0.18.1 - A callable can be used as other. - inplace : boolean, default False Whether to perform the operation in place on the data axis : alignment axis if needed, default None diff --git a/pandas/indexes/base.py b/pandas/indexes/base.py index 4837fc0d7438c..dcbcccdfcd610 100644 --- a/pandas/indexes/base.py +++ b/pandas/indexes/base.py @@ -573,7 +573,7 @@ def repeat(self, repeats, *args, **kwargs): Parameters ---------- - cond : boolean same length as self + cond : boolean array-like with the same length as self other : scalar, or array-like """ diff --git a/pandas/tests/frame/test_indexing.py b/pandas/tests/frame/test_indexing.py index c06faa75ed346..18fb17b98570a 100644 --- a/pandas/tests/frame/test_indexing.py +++ b/pandas/tests/frame/test_indexing.py @@ -2479,6 +2479,95 @@ def _check_set(df, cond, check_dtypes=True): expected = df[df['a'] == 1].reindex(df.index) assert_frame_equal(result, expected) + def test_where_array_like(self): + # see gh-15414 + klasses = [list, tuple, np.array] + + df = DataFrame({'a': [1, 2, 3]}) + cond = [[False], [True], [True]] + expected = DataFrame({'a': [np.nan, 2, 3]}) + + for klass in klasses: + result = df.where(klass(cond)) + assert_frame_equal(result, expected) + + df['b'] = 2 + expected['b'] = [2, np.nan, 2] + cond = [[False, True], [True, False], [True, True]] + + for klass in klasses: + result = df.where(klass(cond)) + assert_frame_equal(result, expected) + + def test_where_invalid_input(self): + # see gh-15414: only boolean arrays accepted + df = DataFrame({'a': [1, 2, 3]}) + msg = "Boolean array expected for the condition" + + conds = [ + [[1], [0], [1]], + Series([[2], [5], [7]]), + DataFrame({'a': [2, 5, 7]}), + [["True"], ["False"], ["True"]], + [[Timestamp("2017-01-01")], + [pd.NaT], [Timestamp("2017-01-02")]] + ] + + for cond in conds: + with tm.assertRaisesRegexp(ValueError, msg): + df.where(cond) + + df['b'] = 2 + conds = [ + [[0, 1], [1, 0], [1, 1]], + Series([[0, 2], [5, 0], [4, 7]]), + [["False", "True"], ["True", "False"], + ["True", "True"]], + DataFrame({'a': [2, 5, 7], 'b': [4, 8, 9]}), + [[pd.NaT, Timestamp("2017-01-01")], + [Timestamp("2017-01-02"), pd.NaT], + [Timestamp("2017-01-03"), Timestamp("2017-01-03")]] + ] + + for cond in conds: + with tm.assertRaisesRegexp(ValueError, msg): + df.where(cond) + + def test_where_dataframe_col_match(self): + df = DataFrame([[1, 2, 3], [4, 5, 6]]) + cond = DataFrame([[True, False, True], [False, False, True]]) + + out = df.where(cond) + expected = DataFrame([[1.0, np.nan, 3], [np.nan, np.nan, 6]]) + tm.assert_frame_equal(out, expected) + + cond.columns = ["a", "b", "c"] # Columns no longer match. + msg = "Boolean array expected for the condition" + with tm.assertRaisesRegexp(ValueError, msg): + df.where(cond) + + def test_where_ndframe_align(self): + msg = "Array conditional must be same shape as self" + df = DataFrame([[1, 2, 3], [4, 5, 6]]) + + cond = [True] + with tm.assertRaisesRegexp(ValueError, msg): + df.where(cond) + + expected = DataFrame([[1, 2, 3], [np.nan, np.nan, np.nan]]) + + out = df.where(Series(cond)) + tm.assert_frame_equal(out, expected) + + cond = np.array([False, True, False, True]) + with tm.assertRaisesRegexp(ValueError, msg): + df.where(cond) + + expected = DataFrame([[np.nan, np.nan, np.nan], [4, 5, 6]]) + + out = df.where(Series(cond)) + tm.assert_frame_equal(out, expected) + def test_where_bug(self): # GH 2793 diff --git a/pandas/tests/indexes/common.py b/pandas/tests/indexes/common.py index 81ad0524807f3..7b39a33266ffa 100644 --- a/pandas/tests/indexes/common.py +++ b/pandas/tests/indexes/common.py @@ -497,6 +497,18 @@ def test_where(self): result = i.where(cond) tm.assert_index_equal(result, expected) + def test_where_array_like(self): + i = self.create_index() + + _nan = i._na_value + cond = [False] + [True] * (len(i) - 1) + klasses = [list, tuple, np.array, pd.Series] + expected = pd.Index([_nan] + i[1:].tolist(), dtype=i.dtype) + + for klass in klasses: + result = i.where(klass(cond)) + tm.assert_index_equal(result, expected) + def test_setops_errorcases(self): for name, idx in compat.iteritems(self.indices): # # non-iterable input diff --git a/pandas/tests/indexes/period/test_period.py b/pandas/tests/indexes/period/test_period.py index 6a8128bb8985f..b80ab6feeeb23 100644 --- a/pandas/tests/indexes/period/test_period.py +++ b/pandas/tests/indexes/period/test_period.py @@ -89,13 +89,22 @@ def test_where(self): expected = i tm.assert_index_equal(result, expected) - i2 = i.copy() i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(), freq='D') result = i.where(notnull(i2)) expected = i2 tm.assert_index_equal(result, expected) + def test_where_array_like(self): + i = self.create_index() + cond = [False] + [True] * (len(i) - 1) + klasses = [list, tuple, np.array, Series] + expected = pd.PeriodIndex([pd.NaT] + i[1:].tolist(), freq='D') + + for klass in klasses: + result = i.where(klass(cond)) + tm.assert_index_equal(result, expected) + def test_where_other(self): i = self.create_index() diff --git a/pandas/tests/indexes/test_category.py b/pandas/tests/indexes/test_category.py index 6b6885c082533..64a0e71bd5ace 100644 --- a/pandas/tests/indexes/test_category.py +++ b/pandas/tests/indexes/test_category.py @@ -240,13 +240,23 @@ def test_where(self): expected = i tm.assert_index_equal(result, expected) - i2 = i.copy() i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(), categories=i.categories) result = i.where(notnull(i2)) expected = i2 tm.assert_index_equal(result, expected) + def test_where_array_like(self): + i = self.create_index() + cond = [False] + [True] * (len(i) - 1) + klasses = [list, tuple, np.array, pd.Series] + expected = pd.CategoricalIndex([np.nan] + i[1:].tolist(), + categories=i.categories) + + for klass in klasses: + result = i.where(klass(cond)) + tm.assert_index_equal(result, expected) + def test_append(self): ci = self.create_index() diff --git a/pandas/tests/indexes/test_multi.py b/pandas/tests/indexes/test_multi.py index 5611492b4af1b..80ff67ab3d043 100644 --- a/pandas/tests/indexes/test_multi.py +++ b/pandas/tests/indexes/test_multi.py @@ -88,6 +88,15 @@ def f(): self.assertRaises(NotImplementedError, f) + def test_where_array_like(self): + i = MultiIndex.from_tuples([('A', 1), ('A', 2)]) + klasses = [list, tuple, np.array, pd.Series] + cond = [False, True] + + for klass in klasses: + f = lambda: i.where(klass(cond)) + self.assertRaises(NotImplementedError, f) + def test_repeat(self): reps = 2 numbers = [1, 2, 3] diff --git a/pandas/tests/series/test_indexing.py b/pandas/tests/series/test_indexing.py index a20cb8324d2a3..8a2cc53b42938 100644 --- a/pandas/tests/series/test_indexing.py +++ b/pandas/tests/series/test_indexing.py @@ -1193,6 +1193,60 @@ def f(): expected = Series(np.nan, index=[9]) assert_series_equal(result, expected) + def test_where_array_like(self): + # see gh-15414 + s = Series([1, 2, 3]) + cond = [False, True, True] + expected = Series([np.nan, 2, 3]) + klasses = [list, tuple, np.array, Series] + + for klass in klasses: + result = s.where(klass(cond)) + assert_series_equal(result, expected) + + def test_where_invalid_input(self): + # see gh-15414: only boolean arrays accepted + s = Series([1, 2, 3]) + msg = "Boolean array expected for the condition" + + conds = [ + [1, 0, 1], + Series([2, 5, 7]), + ["True", "False", "True"], + [Timestamp("2017-01-01"), + pd.NaT, Timestamp("2017-01-02")] + ] + + for cond in conds: + with tm.assertRaisesRegexp(ValueError, msg): + s.where(cond) + + msg = "Array conditional must be same shape as self" + with tm.assertRaisesRegexp(ValueError, msg): + s.where([True]) + + def test_where_ndframe_align(self): + msg = "Array conditional must be same shape as self" + s = Series([1, 2, 3]) + + cond = [True] + with tm.assertRaisesRegexp(ValueError, msg): + s.where(cond) + + expected = Series([1, np.nan, np.nan]) + + out = s.where(Series(cond)) + tm.assert_series_equal(out, expected) + + cond = np.array([False, True, False, True]) + with tm.assertRaisesRegexp(ValueError, msg): + s.where(cond) + + expected = Series([np.nan, 2, np.nan]) + + out = s.where(Series(cond)) + tm.assert_series_equal(out, expected) + def test_where_setitem_invalid(self): # GH 2702