Skip to content

ENH: NDArrayBackedExtensionArray.fillna(method) with 2d #40294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 9, 2021
8 changes: 6 additions & 2 deletions pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,12 @@ def fillna(

if mask.any():
if method is not None:
func = missing.get_fill_func(method)
new_values, _ = func(self._ndarray.copy(), limit=limit, mask=mask)
# TODO: check value is None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already done in validate_fillna_kwargs

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, will update (possibly in upcoming CLN PR to avoid clogging the CI right now)

# (for now) when self.ndim == 2, we assume axis=0
func = missing.get_fill_func(method, ndim=self.ndim)
new_values, _ = func(self._ndarray.T.copy(), limit=limit, mask=mask.T)
new_values = new_values.T

# TODO: PandasArray didn't used to copy, need tests for this
new_values = self._from_backing_data(new_values)
else:
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,14 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
m8arr = self._ndarray.view("M8[ns]")
return m8arr.searchsorted(value, side=side, sorter=sorter)

def fillna(self, value=None, method=None, limit=None) -> PeriodArray:
    """
    Fill missing entries, by value or by a fill method.

    Without a ``method`` the call is delegated unchanged to the parent
    class.  With a ``method`` the data is reinterpreted as ``M8[ns]``
    first, filled there, and the result is viewed back as this array's
    own dtype.

    Parameters
    ----------
    value : optional
        Passed through unchanged to the underlying ``fillna``.
    method : optional
        Fill method name; when not None the datetime64 round-trip is used.
    limit : optional
        Passed through unchanged to the underlying ``fillna``.

    Returns
    -------
    PeriodArray
    """
    if method is None:
        return super().fillna(value=value, method=method, limit=limit)

    # view as dt64 so we get treated as timelike in core.missing
    as_dt64 = self.view("M8[ns]")
    filled = as_dt64.fillna(value=value, method=method, limit=limit)
    return filled.view(self.dtype)

# ------------------------------------------------------------------
# Arithmetic Methods

Expand Down
12 changes: 4 additions & 8 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,6 @@ def interpolate_2d(
values,
)

orig_values = values

transf = (lambda x: x) if axis == 0 else (lambda x: x.T)

# reshape a 1 dim if needed
Expand All @@ -669,10 +667,6 @@ def interpolate_2d(
if ndim == 1:
result = result[0]

if orig_values.dtype.kind in ["m", "M"]:
# convert float back to datetime64/timedelta64
result = result.view(orig_values.dtype)

return result


Expand Down Expand Up @@ -755,9 +749,11 @@ def _backfill_2d(values, limit=None, mask=None):
_fill_methods = {"pad": _pad_1d, "backfill": _backfill_1d}


def get_fill_func(method, ndim: int = 1):
    """
    Look up the fill routine for ``method`` and the array dimensionality.

    Parameters
    ----------
    method
        Fill method name; normalized with ``clean_fill_method`` before the
        lookup.
    ndim : int, default 1
        When 1, the routine comes from ``_fill_methods`` (the 1-D pad and
        backfill functions).  Any other value selects the 2-D routines.

    Returns
    -------
    callable
        One of ``_pad_1d``, ``_backfill_1d``, ``_pad_2d``, ``_backfill_2d``.

    Raises
    ------
    KeyError
        If the normalized method is neither "pad" nor "backfill".
    """
    method = clean_fill_method(method)
    if ndim == 1:
        return _fill_methods[method]
    # Anything that is not 1-D is dispatched to the 2-D implementations.
    return {"pad": _pad_2d, "backfill": _backfill_2d}[method]


def clean_reindex_fill_method(method):
Expand Down
31 changes: 31 additions & 0 deletions pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,37 @@ def test_fillna_preserves_tz(self, method):
assert arr[2] is pd.NaT
assert dti[2] == pd.Timestamp("2000-01-03", tz="US/Central")

def test_fillna_2d(self):
    """
    Method-based fillna on a 2D tz-aware array fills down each column.

    The expectations below encode filling along axis 0, and the second
    half checks that the result is the same when the backing ndarray is
    F-contiguous rather than C-contiguous.
    """
    dti = pd.date_range("2016-01-01", periods=6, tz="US/Pacific")
    dta = dti._data.reshape(3, 2).copy()
    dta[0, 1] = pd.NaT
    dta[1, 0] = pd.NaT

    # pad: [1, 0] takes the value above it; [0, 1] has nothing above it
    # and so stays NaT
    res1 = dta.fillna(method="pad")
    expected1 = dta.copy()
    expected1[1, 0] = dta[0, 0]
    tm.assert_extension_array_equal(res1, expected1)

    # backfill: both missing entries take the value directly below them
    res2 = dta.fillna(method="backfill")
    expected2 = dta.copy()
    expected2[1, 0] = dta[2, 0]
    expected2[0, 1] = dta[1, 1]
    tm.assert_extension_array_equal(res2, expected2)

    # with different ordering for underlying ndarray; behavior should
    # be unchanged
    dta2 = dta._from_backing_data(dta._ndarray.copy(order="F"))
    assert dta2._ndarray.flags["F_CONTIGUOUS"]
    assert not dta2._ndarray.flags["C_CONTIGUOUS"]
    tm.assert_extension_array_equal(dta, dta2)

    res3 = dta2.fillna(method="pad")
    tm.assert_extension_array_equal(res3, expected1)

    res4 = dta2.fillna(method="backfill")
    tm.assert_extension_array_equal(res4, expected2)

def test_array_interface_tz(self):
tz = "US/Central"
data = DatetimeArray(pd.date_range("2017", periods=2, tz=tz))
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/extension/base/dim2.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,17 @@ def test_concat_2d(self, data):
with pytest.raises(ValueError):
left._concat_same_type([left, right], axis=2)

@pytest.mark.parametrize("method", ["backfill", "pad"])
def test_fillna_2d_method(self, data_missing, method):
    """
    Compare method-based fillna on a 2x2 array with the 1D result.

    The 2D input is ``data_missing`` repeated and reshaped to (2, 2); the
    expected output is the 1D ``fillna`` result put through the same
    repeat-and-reshape.
    """
    arr2d = data_missing.repeat(2).reshape(2, 2)

    # Precondition: the first row is entirely missing, the second has none.
    top_row_missing = arr2d[0].isna()
    assert top_row_missing.all()
    bottom_row_missing = arr2d[1].isna()
    assert not bottom_row_missing.any()

    result = arr2d.fillna(method=method)

    filled_1d = data_missing.fillna(method=method)
    expected = filled_1d.repeat(2).reshape(2, 2)
    self.assert_extension_array_equal(result, expected)

@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
def test_reductions_2d_axis_none(self, data, method, request):
if not hasattr(data, method):
Expand Down