Skip to content

BUG: Series.replace(method='pad') with EA dtypes #44270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,14 @@ def __getitem__(
result = self._from_backing_data(result)
return result

def _pad_mask_inplace(
    self, method: str, limit, mask: npt.NDArray[np.bool_]
) -> None:
    """
    Fill the locations flagged by 'mask' in the backing ndarray in place.

    Parameters
    ----------
    method : str
        Name of the fill method; resolved through ``missing.get_fill_func``.
    limit
        Passed through unchanged to the fill function.
    mask : np.ndarray[bool]
        Locations to fill.
    """
    fill_func = missing.get_fill_func(method, ndim=self.ndim)
    # (for now) when self.ndim == 2, we assume axis=0, so the fill function
    # receives the transposed backing data together with the transposed mask.
    backing = self._ndarray.T
    fill_func(backing, limit=limit, mask=mask.T)

@doc(ExtensionArray.fillna)
def fillna(
self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
Expand Down
18 changes: 18 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,24 @@ def _where(
result[~mask] = val
return result

def _pad_mask_inplace(
    self, method: str, limit, mask: npt.NDArray[np.bool_]
) -> None:
    """
    Replace values in locations specified by 'mask' using pad or backfill.

    The array is cast to object dtype, filled, converted back through
    ``_from_sequence`` with the original dtype, and only the masked
    positions are then written back into ``self``.

    See Also
    --------
    ExtensionArray.fillna
    """
    fill_func = missing.get_fill_func(method)
    as_object = self.astype(object)
    # NB: the fill function gets a copy of the mask; if we don't copy it,
    # it may be altered inplace, which would mess up the
    # `self[mask] = ...` assignment below.
    filled, _ = fill_func(as_object, limit=limit, mask=mask.copy())
    filled = self._from_sequence(filled, dtype=self.dtype)
    self[mask] = filled[mask]

@classmethod
def _empty(cls, shape: Shape, dtype: ExtensionDtype):
"""
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6518,10 +6518,13 @@ def replace(

if isinstance(to_replace, (tuple, list)):
if isinstance(self, ABCDataFrame):
return self.apply(
result = self.apply(
self._constructor_sliced._replace_single,
args=(to_replace, method, inplace, limit),
)
if inplace:
return
return result
self = cast("Series", self)
return self._replace_single(to_replace, method, inplace, limit)

Expand Down
3 changes: 3 additions & 0 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def mask_missing(arr: ArrayLike, values_to_mask) -> npt.NDArray[np.bool_]:
if na_mask.any():
mask |= isna(arr)

if not isinstance(mask, np.ndarray):
# e.g. if arr is IntegerArray, then mask is BooleanArray
mask = mask.to_numpy(dtype=bool, na_value=False)
return mask


Expand Down
20 changes: 10 additions & 10 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4889,23 +4889,23 @@ def _replace_single(self, to_replace, method: str, inplace: bool, limit):
replacement value is given in the replace method
"""

orig_dtype = self.dtype
result = self if inplace else self.copy()
fill_f = missing.get_fill_func(method)

mask = missing.mask_missing(result.values, to_replace)
values, _ = fill_f(result.values, limit=limit, mask=mask)
values = result._values
mask = missing.mask_missing(values, to_replace)

if values.dtype == orig_dtype and inplace:
return

result = self._constructor(values, index=self.index, dtype=self.dtype)
result = result.__finalize__(self)
if isinstance(values, ExtensionArray):
# dispatch to the EA's _pad_mask_inplace method
values._pad_mask_inplace(method, limit, mask)
else:
fill_f = missing.get_fill_func(method)
values, _ = fill_f(values, limit=limit, mask=mask)

if inplace:
self._update_inplace(result)
return

result = self._constructor(values, index=self.index, dtype=self.dtype)
result = result.__finalize__(self)
return result

# error: Cannot determine type of 'shift'
Expand Down
50 changes: 50 additions & 0 deletions pandas/tests/series/methods/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,56 @@ def test_replace_extension_other(self, frame_or_series):
# should not have changed dtype
tm.assert_equal(obj, result)

def _check_replace_with_method(self, ser: pd.Series):
    """
    Check ``replace(..., method="pad")`` for ``ser`` and its one-column frame.

    The value at label 1 is used as the replacement target, so the expected
    result holds ``ser[0]`` twice followed by the values of ``ser[2:]``.
    Both the copying call and the ``inplace=True`` call are verified for
    the Series and for the DataFrame.
    """
    frame = ser.to_frame()
    target = ser[1]

    # Non-inplace, Series
    replaced = ser.replace(target, method="pad")
    expected = pd.Series([ser[0], ser[0], *ser[2:]], dtype=ser.dtype)
    tm.assert_series_equal(replaced, expected)

    # Non-inplace, DataFrame
    replaced_frame = frame.replace(target, method="pad")
    tm.assert_frame_equal(replaced_frame, expected.to_frame())

    # Inplace, Series: must return None and mutate the copy
    ser_copy = ser.copy()
    inplace_ret = ser_copy.replace(target, method="pad", inplace=True)
    assert inplace_ret is None
    tm.assert_series_equal(ser_copy, expected)

    # Inplace, DataFrame: must return None and mutate the frame
    inplace_frame_ret = frame.replace(target, method="pad", inplace=True)
    assert inplace_frame_ret is None
    tm.assert_frame_equal(frame, expected.to_frame())

def test_replace_ea_dtype_with_method(self, any_numeric_ea_dtype):
    """
    Run the pad-replace checks on a Series built from ``[1, 2, pd.NA, 4]``
    with the dtype supplied by the ``any_numeric_ea_dtype`` fixture.
    """
    values = [1, 2, pd.NA, 4]
    ser = pd.Series(pd.array(values, dtype=any_numeric_ea_dtype))

    self._check_replace_with_method(ser)

@pytest.mark.parametrize("as_categorical", [True, False])
def test_replace_interval_with_method(self, as_categorical):
    """
    Run the pad-replace checks on an interval Series, optionally cast to
    categorical.
    """
    # in particular interval that can't hold NA
    intervals = pd.IntervalIndex.from_breaks(range(4))
    ser = pd.Series(intervals)
    ser = ser.astype("category") if as_categorical else ser

    self._check_replace_with_method(ser)

@pytest.mark.parametrize("as_period", [True, False])
@pytest.mark.parametrize("as_categorical", [True, False])
def test_replace_datetimelike_with_method(self, as_period, as_categorical):
    """
    Run the pad-replace checks on a tz-aware datetime Series (or its daily
    period equivalent) holding one NaT, optionally cast to categorical.
    """
    index = pd.date_range("2016-01-01", periods=5, tz="US/Pacific")
    if as_period:
        # drop the timezone before converting to daily periods
        naive = index.tz_localize(None)
        index = naive.to_period("D")

    ser = pd.Series(index)
    # put a missing value in the second-to-last position
    ser.iloc[-2] = pd.NaT
    ser = ser.astype("category") if as_categorical else ser

    self._check_replace_with_method(ser)

def test_replace_with_compiled_regex(self):
# https://github.com/pandas-dev/pandas/issues/35680
s = pd.Series(["a", "b", "c"])
Expand Down