Skip to content

Commit 602ab16

Browse files
authored
ENH: NDArrayBackedExtensionArray.fillna(method) with 2d (#40294)
1 parent c29facc commit 602ab16

File tree

5 files changed

+60
-10
lines changed

5 files changed

+60
-10
lines changed

pandas/core/arrays/_mixins.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,12 @@ def fillna(
278278

279279
if mask.any():
280280
if method is not None:
281-
func = missing.get_fill_func(method)
282-
new_values, _ = func(self._ndarray.copy(), limit=limit, mask=mask)
281+
# TODO: check value is None
282+
# (for now) when self.ndim == 2, we assume axis=0
283+
func = missing.get_fill_func(method, ndim=self.ndim)
284+
new_values, _ = func(self._ndarray.T.copy(), limit=limit, mask=mask.T)
285+
new_values = new_values.T
286+
283287
# TODO: PandasArray didn't used to copy, need tests for this
284288
new_values = self._from_backing_data(new_values)
285289
else:

pandas/core/arrays/period.py

+8
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,14 @@ def searchsorted(self, value, side="left", sorter=None) -> np.ndarray:
639639
m8arr = self._ndarray.view("M8[ns]")
640640
return m8arr.searchsorted(value, side=side, sorter=sorter)
641641

642+
def fillna(self, value=None, method=None, limit=None) -> PeriodArray:
    """
    Fill missing values, routing method-based fills through datetime64.

    Parameters
    ----------
    value : optional
        Passed through unchanged to the underlying ``fillna`` call.
    method : optional
        When not None, the fill is performed on an ``"M8[ns]"`` view of
        this array and the result is viewed back as ``self.dtype``.
    limit : optional
        Passed through unchanged to the underlying ``fillna`` call.

    Returns
    -------
    PeriodArray
    """
    if method is None:
        # No method-based fill requested: defer entirely to the parent class.
        return super().fillna(value=value, method=method, limit=limit)

    # view as dt64 so we get treated as timelike in core.missing
    as_dt64 = self.view("M8[ns]")
    filled = as_dt64.fillna(value=value, method=method, limit=limit)
    return filled.view(self.dtype)
649+
642650
# ------------------------------------------------------------------
643651
# Arithmetic Methods
644652

pandas/core/missing.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,6 @@ def interpolate_2d(
646646
values,
647647
)
648648

649-
orig_values = values
650-
651649
transf = (lambda x: x) if axis == 0 else (lambda x: x.T)
652650

653651
# reshape a 1 dim if needed
@@ -669,10 +667,6 @@ def interpolate_2d(
669667
if ndim == 1:
670668
result = result[0]
671669

672-
if orig_values.dtype.kind in ["m", "M"]:
673-
# convert float back to datetime64/timedelta64
674-
result = result.view(orig_values.dtype)
675-
676670
return result
677671

678672

@@ -755,9 +749,11 @@ def _backfill_2d(values, limit=None, mask=None):
755749
_fill_methods = {"pad": _pad_1d, "backfill": _backfill_1d}
756750

757751

758-
def get_fill_func(method, ndim: int = 1):
    """
    Look up the pad/backfill implementation for ``method``.

    Parameters
    ----------
    method
        Fill method name; normalized with ``clean_fill_method`` before the
        lookup.
    ndim : int, default 1
        When equal to 1 the function is taken from ``_fill_methods``; for
        any other value the 2D implementations (``_pad_2d`` /
        ``_backfill_2d``) are used.

    Returns
    -------
    callable
        The fill function registered for the normalized method.
    """
    method = clean_fill_method(method)
    if ndim != 1:
        fill_funcs_2d = {"pad": _pad_2d, "backfill": _backfill_2d}
        return fill_funcs_2d[method]
    return _fill_methods[method]
761757

762758

763759
def clean_reindex_fill_method(method):

pandas/tests/arrays/test_datetimes.py

+31
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,37 @@ def test_fillna_preserves_tz(self, method):
195195
assert arr[2] is pd.NaT
196196
assert dti[2] == pd.Timestamp("2000-01-03", tz="US/Central")
197197

198+
def test_fillna_2d(self):
    """
    Check method-based ``fillna`` on a 2D tz-aware datetime array.

    A (3, 2) array is built with NaT at positions [0, 1] and [1, 0]; the
    results of ``pad`` and ``backfill`` are compared against hand-built
    expectations, first for the original array and then for an equal
    array whose backing ndarray is F-contiguous.
    """
    dti = pd.date_range("2016-01-01", periods=6, tz="US/Pacific")
    dta = dti._data.reshape(3, 2).copy()
    dta[0, 1] = pd.NaT
    dta[1, 0] = pd.NaT

    # pad: only [1, 0] is expected to be filled (from [0, 0]); [0, 1] is
    # left as NaT in the expected result.
    res1 = dta.fillna(method="pad")
    expected1 = dta.copy()
    expected1[1, 0] = dta[0, 0]
    tm.assert_extension_array_equal(res1, expected1)

    # backfill: both holes are expected to be filled from the row below.
    res2 = dta.fillna(method="backfill")
    expected2 = dta.copy()
    expected2[1, 0] = dta[2, 0]
    expected2[0, 1] = dta[1, 1]
    tm.assert_extension_array_equal(res2, expected2)

    # with different ordering for underlying ndarray; behavior should
    # be unchanged
    dta2 = dta._from_backing_data(dta._ndarray.copy(order="F"))
    assert dta2._ndarray.flags["F_CONTIGUOUS"]
    assert not dta2._ndarray.flags["C_CONTIGUOUS"]
    tm.assert_extension_array_equal(dta, dta2)

    res3 = dta2.fillna(method="pad")
    tm.assert_extension_array_equal(res3, expected1)

    res4 = dta2.fillna(method="backfill")
    tm.assert_extension_array_equal(res4, expected2)
228+
198229
def test_array_interface_tz(self):
199230
tz = "US/Central"
200231
data = DatetimeArray(pd.date_range("2017", periods=2, tz=tz))

pandas/tests/extension/base/dim2.py

+11
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,17 @@ def test_concat_2d(self, data):
131131
with pytest.raises(ValueError):
132132
left._concat_same_type([left, right], axis=2)
133133

134+
@pytest.mark.parametrize("method", ["backfill", "pad"])
def test_fillna_2d_method(self, data_missing, method):
    """
    Compare a method-based ``fillna`` on a 2D array with the 1D result.

    The 2D input is ``data_missing`` repeated twice and reshaped to
    (2, 2); the expected value is the 1D ``fillna`` result put through
    the same repeat/reshape.
    """
    arr = data_missing.repeat(2).reshape(2, 2)

    # Sanity-check the fixture layout: first row all-NA, second row no NA.
    first_row, second_row = arr[0], arr[1]
    assert first_row.isna().all()
    assert not second_row.isna().any()

    result = arr.fillna(method=method)

    filled_1d = data_missing.fillna(method=method)
    expected = filled_1d.repeat(2).reshape(2, 2)
    self.assert_extension_array_equal(result, expected)
144+
134145
@pytest.mark.parametrize("method", ["mean", "median", "var", "std", "sum", "prod"])
135146
def test_reductions_2d_axis_none(self, data, method, request):
136147
if not hasattr(data, method):

0 commit comments

Comments
 (0)