From d623ea0c1de11824cfa49c2e4d5b269ed3b8dabe Mon Sep 17 00:00:00 2001 From: JustinZhengBC Date: Sat, 8 Dec 2018 19:47:15 -0800 Subject: [PATCH] BUG-16983 fix df.where with extension dtypes --- doc/source/whatsnew/v0.24.0.rst | 1 + pandas/core/internals/blocks.py | 2 ++ pandas/tests/frame/test_indexing.py | 9 +++++++++ 3 files changed, 12 insertions(+) diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index 0b2b526dfe9e7..0155e70384cfd 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -1418,6 +1418,7 @@ Indexing - Bug in :func:`Index.union` and :func:`Index.intersection` where name of the ``Index`` of the result was not computed correctly for certain cases (:issue:`9943`, :issue:`9862`) - Bug in :class:`Index` slicing with boolean :class:`Index` may raise ``TypeError`` (:issue:`22533`) - Bug in ``PeriodArray.__setitem__`` when accepting slice and list-like value (:issue:`23978`) +- Bug in :func:`DataFrame.where` where a ``ValueError`` would raise when a column had an extension dtype (:issue:`16983`) Missing ^^^^^^^ diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 51c47a81f8e2f..b68f187d71e41 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1341,6 +1341,8 @@ def func(cond, values, other): return values values, other = self._try_coerce_args(values, other) + if isinstance(values, ExtensionArray): + values = values.get_values().reshape(-1, 1) # GH 16983 try: return self._try_coerce_result(expressions.where( diff --git a/pandas/tests/frame/test_indexing.py b/pandas/tests/frame/test_indexing.py index b95dad422e90a..cc7a4fb28aace 100644 --- a/pandas/tests/frame/test_indexing.py +++ b/pandas/tests/frame/test_indexing.py @@ -3089,6 +3089,15 @@ def test_where_tz_values(self, tz_naive_fixture): result = df1.where(mask, df2) assert_frame_equal(exp, result) + def test_where_extension_dtypes_with_na(self): + from pandas.core.sparse.api import SparseDataFrame + sdf = DataFrame(SparseDataFrame([1, None])) + cdf = DataFrame([1, None]).astype('category') + df = DataFrame([1, None]) + + assert_frame_equal(df.where(df.isna()), sdf.where(sdf.isna())) + assert_frame_equal(df.where(df.isna()), cdf.where(cdf.isna())) + def test_mask(self): df = DataFrame(np.random.randn(5, 3)) cond = df > 0