From 6769fbda4b89878517385e4395056afd910875d5 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 27 Dec 2020 20:04:55 -0800 Subject: [PATCH 1/4] CLN: NDFrame._where --- pandas/core/generic.py | 39 +++++++++++++++++++++++++-------- pandas/core/internals/blocks.py | 11 +--------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index bdb28c10a0ad2..f5c1f89d0cd8e 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -89,9 +89,10 @@ import pandas as pd from pandas.core import arraylike, indexing, missing, nanops import pandas.core.algorithms as algos +from pandas.core.arrays import ExtensionArray from pandas.core.base import PandasObject, SelectionMixin import pandas.core.common as com -from pandas.core.construction import create_series_with_explicit_dtype +from pandas.core.construction import create_series_with_explicit_dtype, extract_array from pandas.core.flags import Flags from pandas.core.indexes import base as ibase from pandas.core.indexes.api import ( @@ -8786,6 +8787,9 @@ def _where( """ inplace = validate_bool_kwarg(inplace, "inplace") + if axis is not None: + axis = self._get_axis_number(axis) + # align the cond to same shape as myself cond = com.apply_if_callable(cond, self) if isinstance(cond, NDFrame): @@ -8825,22 +8829,39 @@ def _where( if other.ndim <= self.ndim: _, other = self.align( - other, join="left", axis=axis, level=level, fill_value=np.nan + other, + join="left", + axis=axis, + level=level, + fill_value=np.nan, + copy=False, ) # if we are NOT aligned, raise as we cannot where index - if axis is None and not all( - other._get_axis(i).equals(ax) for i, ax in enumerate(self.axes) - ): + if axis is None and not other._indexed_same(self): raise InvalidIndexError + elif other.ndim < self.ndim: + # TODO(EA2D): avoid object-dtype cast in EA case GH#38729 + other = other._values + if axis == 0: + other = np.reshape(other, (-1, 1)) + if axis == 1: + other = np.reshape(other, (1, -1)) + + other = np.broadcast_to(other, self.shape) + # slice me out of the other else: raise NotImplementedError( "cannot align with a higher dimensional NDFrame" ) - if isinstance(other, np.ndarray): + if not isinstance(other, (MultiIndex, NDFrame)): + # mainly just catching Index here + other = extract_array(other, extract_numpy=True) + + if isinstance(other, (np.ndarray, ExtensionArray)): if other.shape != self.shape: @@ -8885,10 +8906,10 @@ def _where( else: align = self._get_axis_number(axis) == 1 - if align and isinstance(other, NDFrame): - other = other.reindex(self._info_axis, axis=self._info_axis_number) if isinstance(cond, NDFrame): - cond = cond.reindex(self._info_axis, axis=self._info_axis_number) + cond = cond.reindex( + self._info_axis, axis=self._info_axis_number, copy=False + ) block_axis = self._get_block_manager_axis(axis) diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index ea1b8259eeadd..d42039e710666 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1064,9 +1064,7 @@ def putmask(self, mask, new, axis: int = 0) -> List["Block"]: # If the default repeat behavior in np.putmask would go in the # wrong direction, then explicitly repeat and reshape new instead if getattr(new, "ndim", 0) >= 1: - if self.ndim - 1 == new.ndim and axis == 1: - new = np.repeat(new, new_values.shape[-1]).reshape(self.shape) - new = new.astype(new_values.dtype) + new = new.astype(new_values.dtype, copy=False) # we require exact matches between the len of the # values we are setting (or is compat). np.putmask @@ -1104,13 +1102,6 @@ def putmask(self, mask, new, axis: int = 0) -> List["Block"]: new = new.T axis = new_values.ndim - axis - 1 - # Pseudo-broadcast - if getattr(new, "ndim", 0) >= 1: - if self.ndim - 1 == new.ndim: - new_shape = list(new.shape) - new_shape.insert(axis, 1) - new = new.reshape(tuple(new_shape)) - # operate column-by-column def f(mask, val, idx): From be8ce6874343e67cb930887bff90b37ab7ff0150 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 28 Dec 2020 09:53:13 -0800 Subject: [PATCH 2/4] TST: test with Series[Int64] --- pandas/tests/frame/indexing/test_where.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index acdb5726e4adb..87d2fd37ab023 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -653,3 +653,22 @@ def test_where_categorical_filtering(self): expected.loc[0, :] = np.nan tm.assert_equal(result, expected) + + def test_where_ea_other(self): + # GH#38729/GH#38742 + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + arr = pd.array([7, pd.NA, 9]) + ser = Series(arr) + mask = np.ones(df.shape, dtype=bool) + mask[1, :] = False + + # TODO: ideally we would get Int64 instead of object + result = df.where(mask, ser, axis=0) + expected = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}).astype(object) + tm.assert_frame_equal(result, expected) + + ser2 = Series(arr[:2], index=["A", "B"]) + expected = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]}) + expected["B"] = expected["B"].astype(object) + result = df.where(mask, ser2, axis=1) + tm.assert_frame_equal(result, expected) From f21f4482d87f75987a53f800ac7eeea1ab76d2cd Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 28 Dec 2020 09:54:17 -0800 Subject: [PATCH 3/4] whatsnew --- doc/source/whatsnew/v1.3.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index b41931a803053..c7a1e006e8b73 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -293,7 +293,7 @@ Sparse ExtensionArray ^^^^^^^^^^^^^^ - +- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with ExtensionArray dtype (:issue:`38729`) - - From 2bc397bb39551b32ccd1c7ada2181f6212b01a63 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 28 Dec 2020 10:30:29 -0800 Subject: [PATCH 4/4] if->elif --- pandas/core/generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f5c1f89d0cd8e..9f84447b7476d 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8846,7 +8846,7 @@ def _where( other = other._values if axis == 0: other = np.reshape(other, (-1, 1)) - if axis == 1: + elif axis == 1: other = np.reshape(other, (1, -1)) other = np.broadcast_to(other, self.shape)