Skip to content

Commit e752928

Browse files
authored
CLN: NDFrame._where (#38742)
1 parent 5a6a0f7 commit e752928

File tree

4 files changed

+51
-20
lines changed

4 files changed

+51
-20
lines changed

doc/source/whatsnew/v1.3.0.rst

+1 -1
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,7 @@ Sparse
295295

296296
ExtensionArray
297297
^^^^^^^^^^^^^^
298-
298+
- Bug in :meth:`DataFrame.where` when ``other`` is a :class:`Series` with ExtensionArray dtype (:issue:`38729`)
299299
-
300300
-
301301

pandas/core/generic.py

+30 -9
Original file line number | Diff line number | Diff line change
@@ -89,9 +89,10 @@
8989
import pandas as pd
9090
from pandas.core import arraylike, indexing, missing, nanops
9191
import pandas.core.algorithms as algos
92+
from pandas.core.arrays import ExtensionArray
9293
from pandas.core.base import PandasObject, SelectionMixin
9394
import pandas.core.common as com
94-
from pandas.core.construction import create_series_with_explicit_dtype
95+
from pandas.core.construction import create_series_with_explicit_dtype, extract_array
9596
from pandas.core.flags import Flags
9697
from pandas.core.indexes import base as ibase
9798
from pandas.core.indexes.api import (
@@ -8788,6 +8789,9 @@ def _where(
87888789
"""
87898790
inplace = validate_bool_kwarg(inplace, "inplace")
87908791

8792+
if axis is not None:
8793+
axis = self._get_axis_number(axis)
8794+
87918795
# align the cond to same shape as myself
87928796
cond = com.apply_if_callable(cond, self)
87938797
if isinstance(cond, NDFrame):
@@ -8827,22 +8831,39 @@ def _where(
88278831
if other.ndim <= self.ndim:
88288832

88298833
_, other = self.align(
8830-
other, join="left", axis=axis, level=level, fill_value=np.nan
8834+
other,
8835+
join="left",
8836+
axis=axis,
8837+
level=level,
8838+
fill_value=np.nan,
8839+
copy=False,
88318840
)
88328841

88338842
# if we are NOT aligned, raise as we cannot where index
8834-
if axis is None and not all(
8835-
other._get_axis(i).equals(ax) for i, ax in enumerate(self.axes)
8836-
):
8843+
if axis is None and not other._indexed_same(self):
88378844
raise InvalidIndexError
88388845

8846+
elif other.ndim < self.ndim:
8847+
# TODO(EA2D): avoid object-dtype cast in EA case GH#38729
8848+
other = other._values
8849+
if axis == 0:
8850+
other = np.reshape(other, (-1, 1))
8851+
elif axis == 1:
8852+
other = np.reshape(other, (1, -1))
8853+
8854+
other = np.broadcast_to(other, self.shape)
8855+
88398856
# slice me out of the other
88408857
else:
88418858
raise NotImplementedError(
88428859
"cannot align with a higher dimensional NDFrame"
88438860
)
88448861

8845-
if isinstance(other, np.ndarray):
8862+
if not isinstance(other, (MultiIndex, NDFrame)):
8863+
# mainly just catching Index here
8864+
other = extract_array(other, extract_numpy=True)
8865+
8866+
if isinstance(other, (np.ndarray, ExtensionArray)):
88468867

88478868
if other.shape != self.shape:
88488869

@@ -8887,10 +8908,10 @@ def _where(
88878908
else:
88888909
align = self._get_axis_number(axis) == 1
88898910

8890-
if align and isinstance(other, NDFrame):
8891-
other = other.reindex(self._info_axis, axis=self._info_axis_number)
88928911
if isinstance(cond, NDFrame):
8893-
cond = cond.reindex(self._info_axis, axis=self._info_axis_number)
8912+
cond = cond.reindex(
8913+
self._info_axis, axis=self._info_axis_number, copy=False
8914+
)
88948915

88958916
block_axis = self._get_block_manager_axis(axis)
88968917

pandas/core/internals/blocks.py

+1 -10
Original file line number | Diff line number | Diff line change
@@ -1064,9 +1064,7 @@ def putmask(self, mask, new, axis: int = 0) -> List["Block"]:
10641064
# If the default repeat behavior in np.putmask would go in the
10651065
# wrong direction, then explicitly repeat and reshape new instead
10661066
if getattr(new, "ndim", 0) >= 1:
1067-
if self.ndim - 1 == new.ndim and axis == 1:
1068-
new = np.repeat(new, new_values.shape[-1]).reshape(self.shape)
1069-
new = new.astype(new_values.dtype)
1067+
new = new.astype(new_values.dtype, copy=False)
10701068

10711069
# we require exact matches between the len of the
10721070
# values we are setting (or is compat). np.putmask
@@ -1104,13 +1102,6 @@ def putmask(self, mask, new, axis: int = 0) -> List["Block"]:
11041102
new = new.T
11051103
axis = new_values.ndim - axis - 1
11061104

1107-
# Pseudo-broadcast
1108-
if getattr(new, "ndim", 0) >= 1:
1109-
if self.ndim - 1 == new.ndim:
1110-
new_shape = list(new.shape)
1111-
new_shape.insert(axis, 1)
1112-
new = new.reshape(tuple(new_shape))
1113-
11141105
# operate column-by-column
11151106
def f(mask, val, idx):
11161107

pandas/tests/frame/indexing/test_where.py

+19
Original file line number | Diff line number | Diff line change
@@ -653,3 +653,22 @@ def test_where_categorical_filtering(self):
653653
expected.loc[0, :] = np.nan
654654

655655
tm.assert_equal(result, expected)
656+
657+
def test_where_ea_other(self):
658+
# GH#38729/GH#38742
659+
df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
660+
arr = pd.array([7, pd.NA, 9])
661+
ser = Series(arr)
662+
mask = np.ones(df.shape, dtype=bool)
663+
mask[1, :] = False
664+
665+
# TODO: ideally we would get Int64 instead of object
666+
result = df.where(mask, ser, axis=0)
667+
expected = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}).astype(object)
668+
tm.assert_frame_equal(result, expected)
669+
670+
ser2 = Series(arr[:2], index=["A", "B"])
671+
expected = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]})
672+
expected["B"] = expected["B"].astype(object)
673+
result = df.where(mask, ser2, axis=1)
674+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)