Skip to content

Commit 8757a3c

Browse files
authored
Backport PR pandas-dev#56757 on branch 2.2.x (ENH: Implement interpolation for arrow and masked dtypes) (pandas-dev#56809)
ENH: Implement interpolation for arrow and masked dtypes (pandas-dev#56757) * ENH: Implement interpolation for arrow and masked dtypes * Fixup * Fix typing * Update (cherry picked from commit 5fc2ed2)
1 parent 596ea0b commit 8757a3c

File tree

5 files changed

+143
-7
lines changed

5 files changed

+143
-7
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ Other enhancements
343343
- :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`)
344344
- :meth:`Series.ffill`, :meth:`Series.bfill`, :meth:`DataFrame.ffill`, and :meth:`DataFrame.bfill` have gained the argument ``limit_area``; 3rd party :class:`.ExtensionArray` authors need to add this argument to the method ``_pad_or_backfill`` (:issue:`56492`)
345345
- Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`)
346+
- Implement :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` for :class:`ArrowDtype` and masked dtypes (:issue:`56267`)
346347
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
347348
- Implemented :meth:`Series.dt` methods and attributes for :class:`ArrowDtype` with ``pyarrow.duration`` type (:issue:`52284`)
348349
- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`)

pandas/core/arrays/arrow/array.py

+40
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def floordiv_compat(
182182
AxisInt,
183183
Dtype,
184184
FillnaOptions,
185+
InterpolateOptions,
185186
Iterator,
186187
NpDtype,
187188
NumpySorter,
@@ -2048,6 +2049,45 @@ def _maybe_convert_setitem_value(self, value):
20482049
raise TypeError(msg) from err
20492050
return value
20502051

2052+
def interpolate(
2053+
self,
2054+
*,
2055+
method: InterpolateOptions,
2056+
axis: int,
2057+
index,
2058+
limit,
2059+
limit_direction,
2060+
limit_area,
2061+
copy: bool,
2062+
**kwargs,
2063+
) -> Self:
2064+
"""
2065+
See NDFrame.interpolate.__doc__.
2066+
"""
2067+
# NB: we return type(self) even if copy=False
2068+
mask = self.isna()
2069+
if self.dtype.kind == "f":
2070+
data = self._pa_array.to_numpy()
2071+
elif self.dtype.kind in "iu":
2072+
data = self.to_numpy(dtype="f8", na_value=0.0)
2073+
else:
2074+
raise NotImplementedError(
2075+
f"interpolate is not implemented for dtype={self.dtype}"
2076+
)
2077+
2078+
missing.interpolate_2d_inplace(
2079+
data,
2080+
method=method,
2081+
axis=0,
2082+
index=index,
2083+
limit=limit,
2084+
limit_direction=limit_direction,
2085+
limit_area=limit_area,
2086+
mask=mask,
2087+
**kwargs,
2088+
)
2089+
return type(self)(self._box_pa_array(pa.array(data, mask=mask)))
2090+
20512091
@classmethod
20522092
def _if_else(
20532093
cls,

pandas/core/arrays/masked.py

+54
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AxisInt,
2323
DtypeObj,
2424
FillnaOptions,
25+
InterpolateOptions,
2526
NpDtype,
2627
PositionalIndexer,
2728
Scalar,
@@ -98,6 +99,7 @@
9899
NumpySorter,
99100
NumpyValueArrayLike,
100101
)
102+
from pandas.core.arrays import FloatingArray
101103

102104
from pandas.compat.numpy import function as nv
103105

@@ -1491,6 +1493,58 @@ def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
14911493
else:
14921494
return self.dtype.na_value
14931495

1496+
def interpolate(
1497+
self,
1498+
*,
1499+
method: InterpolateOptions,
1500+
axis: int,
1501+
index,
1502+
limit,
1503+
limit_direction,
1504+
limit_area,
1505+
copy: bool,
1506+
**kwargs,
1507+
) -> FloatingArray:
1508+
"""
1509+
See NDFrame.interpolate.__doc__.
1510+
"""
1511+
# NB: we return type(self) even if copy=False
1512+
if self.dtype.kind == "f":
1513+
if copy:
1514+
data = self._data.copy()
1515+
mask = self._mask.copy()
1516+
else:
1517+
data = self._data
1518+
mask = self._mask
1519+
elif self.dtype.kind in "iu":
1520+
copy = True
1521+
data = self._data.astype("f8")
1522+
mask = self._mask.copy()
1523+
else:
1524+
raise NotImplementedError(
1525+
f"interpolate is not implemented for dtype={self.dtype}"
1526+
)
1527+
1528+
missing.interpolate_2d_inplace(
1529+
data,
1530+
method=method,
1531+
axis=0,
1532+
index=index,
1533+
limit=limit,
1534+
limit_direction=limit_direction,
1535+
limit_area=limit_area,
1536+
mask=mask,
1537+
**kwargs,
1538+
)
1539+
if not copy:
1540+
return self # type: ignore[return-value]
1541+
if self.dtype.kind == "f":
1542+
return type(self)._simple_new(data, mask) # type: ignore[return-value]
1543+
else:
1544+
from pandas.core.arrays import FloatingArray
1545+
1546+
return FloatingArray._simple_new(data, mask)
1547+
14941548
def _accumulate(
14951549
self, name: str, *, skipna: bool = True, **kwargs
14961550
) -> BaseMaskedArray:

pandas/core/missing.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def interpolate_2d_inplace(
349349
limit_direction: str = "forward",
350350
limit_area: str | None = None,
351351
fill_value: Any | None = None,
352+
mask=None,
352353
**kwargs,
353354
) -> None:
354355
"""
@@ -396,6 +397,7 @@ def func(yvalues: np.ndarray) -> None:
396397
limit_area=limit_area_validated,
397398
fill_value=fill_value,
398399
bounds_error=False,
400+
mask=mask,
399401
**kwargs,
400402
)
401403

@@ -440,6 +442,7 @@ def _interpolate_1d(
440442
fill_value: Any | None = None,
441443
bounds_error: bool = False,
442444
order: int | None = None,
445+
mask=None,
443446
**kwargs,
444447
) -> None:
445448
"""
@@ -453,8 +456,10 @@ def _interpolate_1d(
453456
-----
454457
Fills 'yvalues' in-place.
455458
"""
456-
457-
invalid = isna(yvalues)
459+
if mask is not None:
460+
invalid = mask
461+
else:
462+
invalid = isna(yvalues)
458463
valid = ~invalid
459464

460465
if not valid.any():
@@ -531,7 +536,10 @@ def _interpolate_1d(
531536
**kwargs,
532537
)
533538

534-
if is_datetimelike:
539+
if mask is not None:
540+
mask[:] = False
541+
mask[preserve_nans] = True
542+
elif is_datetimelike:
535543
yvalues[preserve_nans] = NaT.value
536544
else:
537545
yvalues[preserve_nans] = np.nan

pandas/tests/frame/methods/test_interpolate.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,41 @@ def test_interpolate_empty_df(self):
508508
assert result is None
509509
tm.assert_frame_equal(df, expected)
510510

511-
def test_interpolate_ea_raise(self):
511+
def test_interpolate_ea(self, any_int_ea_dtype):
512512
# GH#55347
513-
df = DataFrame({"a": [1, None, 2]}, dtype="Int64")
514-
with pytest.raises(NotImplementedError, match="does not implement"):
515-
df.interpolate()
513+
df = DataFrame({"a": [1, None, None, None, 3]}, dtype=any_int_ea_dtype)
514+
orig = df.copy()
515+
result = df.interpolate(limit=2)
516+
expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="Float64")
517+
tm.assert_frame_equal(result, expected)
518+
tm.assert_frame_equal(df, orig)
519+
520+
@pytest.mark.parametrize(
521+
"dtype",
522+
[
523+
"Float64",
524+
"Float32",
525+
pytest.param("float32[pyarrow]", marks=td.skip_if_no("pyarrow")),
526+
pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
527+
],
528+
)
529+
def test_interpolate_ea_float(self, dtype):
530+
# GH#55347
531+
df = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype)
532+
orig = df.copy()
533+
result = df.interpolate(limit=2)
534+
expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype=dtype)
535+
tm.assert_frame_equal(result, expected)
536+
tm.assert_frame_equal(df, orig)
537+
538+
@pytest.mark.parametrize(
539+
"dtype",
540+
["int64", "uint64", "int32", "int16", "int8", "uint32", "uint16", "uint8"],
541+
)
542+
def test_interpolate_arrow(self, dtype):
543+
# GH#55347
544+
pytest.importorskip("pyarrow")
545+
df = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype + "[pyarrow]")
546+
result = df.interpolate(limit=2)
547+
expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="float64[pyarrow]")
548+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)