Skip to content

Commit 5fc2ed2

Browse files
authored
ENH: Implement interpolation for arrow and masked dtypes (#56757)
* ENH: Implement interpolation for arrow and masked dtypes
* Fixup
* Fix typing
* Update
1 parent fce520d commit 5fc2ed2

File tree

5 files changed

+143
-7
lines changed

5 files changed

+143
-7
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ Other enhancements
343343
- :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`)
344344
- :meth:`Series.ffill`, :meth:`Series.bfill`, :meth:`DataFrame.ffill`, and :meth:`DataFrame.bfill` have gained the argument ``limit_area``; 3rd party :class:`.ExtensionArray` authors need to add this argument to the method ``_pad_or_backfill`` (:issue:`56492`)
345345
- Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`)
346+
- Implement :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` for :class:`ArrowDtype` and masked dtypes (:issue:`56267`)
346347
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
347348
- Implemented :meth:`Series.dt` methods and attributes for :class:`ArrowDtype` with ``pyarrow.duration`` type (:issue:`52284`)
348349
- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`)

pandas/core/arrays/arrow/array.py

+40
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def floordiv_compat(
184184
AxisInt,
185185
Dtype,
186186
FillnaOptions,
187+
InterpolateOptions,
187188
Iterator,
188189
NpDtype,
189190
NumpySorter,
@@ -2068,6 +2069,45 @@ def _maybe_convert_setitem_value(self, value):
20682069
raise TypeError(msg) from err
20692070
return value
20702071

2072+
def interpolate(
    self,
    *,
    method: InterpolateOptions,
    axis: int,
    index,
    limit,
    limit_direction,
    limit_area,
    copy: bool,
    **kwargs,
) -> Self:
    """
    Fill null entries of a float or integer array by interpolation.

    See NDFrame.interpolate.__doc__ for the meaning of the interpolation
    options.

    Parameters
    ----------
    method : InterpolateOptions
        Interpolation technique, forwarded to
        ``missing.interpolate_2d_inplace``.
    axis : int
        Not read by this implementation; values are always interpolated
        along axis 0.
    index
        Forwarded to ``missing.interpolate_2d_inplace``.
    limit, limit_direction, limit_area
        Forwarded unchanged to ``missing.interpolate_2d_inplace``.
    copy : bool
        Not read by this implementation; a new array is always built
        from the interpolated values.
    **kwargs
        Forwarded unchanged to ``missing.interpolate_2d_inplace``.

    Returns
    -------
    Self
        A new ``type(self)`` wrapping the interpolated values. Integer
        input is interpolated as float64.

    Raises
    ------
    NotImplementedError
        If the dtype kind is neither float ("f") nor integer ("i"/"u").
    """
    kind = self.dtype.kind
    if kind != "f" and kind not in "iu":
        # Reject unsupported dtypes before doing any other work.
        raise NotImplementedError(
            f"interpolate is not implemented for dtype={self.dtype}"
        )

    mask = self.isna()
    if kind == "f":
        data = self._pa_array.to_numpy()
    else:
        # Integers are interpolated as float64. Nulls get a placeholder
        # value of 0.0; they are identified through ``mask`` instead.
        data = self.to_numpy(dtype="f8", na_value=0.0)

    # Fills ``data`` in place. Because ``mask`` is passed, the helper
    # uses it to locate missing entries and rewrites it so that it marks
    # only the entries that are still missing afterwards.
    missing.interpolate_2d_inplace(
        data,
        method=method,
        axis=0,
        index=index,
        limit=limit,
        limit_direction=limit_direction,
        limit_area=limit_area,
        mask=mask,
        **kwargs,
    )
    return type(self)(self._box_pa_array(pa.array(data, mask=mask)))
2110+
20712111
@classmethod
20722112
def _if_else(
20732113
cls,

pandas/core/arrays/masked.py

+54
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AxisInt,
2323
DtypeObj,
2424
FillnaOptions,
25+
InterpolateOptions,
2526
NpDtype,
2627
PositionalIndexer,
2728
Scalar,
@@ -99,6 +100,7 @@
99100
NumpyValueArrayLike,
100101
)
101102
from pandas._libs.missing import NAType
103+
from pandas.core.arrays import FloatingArray
102104

103105
from pandas.compat.numpy import function as nv
104106

@@ -1521,6 +1523,58 @@ def all(
15211523
else:
15221524
return self.dtype.na_value
15231525

1526+
def interpolate(
    self,
    *,
    method: InterpolateOptions,
    axis: int,
    index,
    limit,
    limit_direction,
    limit_area,
    copy: bool,
    **kwargs,
) -> FloatingArray:
    """
    Fill masked entries by interpolation.

    See NDFrame.interpolate.__doc__ for the meaning of the interpolation
    options.

    Floating dtypes are interpolated on their own buffers when
    ``copy=False`` (and ``self`` is returned), or on copies otherwise.
    Integer dtypes are always converted to float64 first, so a new
    ``FloatingArray`` is returned regardless of ``copy``.

    Raises
    ------
    NotImplementedError
        If the dtype kind is neither float ("f") nor integer ("i"/"u").
    """
    kind = self.dtype.kind
    is_float = kind == "f"
    if not is_float and kind not in "iu":
        raise NotImplementedError(
            f"interpolate is not implemented for dtype={self.dtype}"
        )

    if is_float:
        values = self._data.copy() if copy else self._data
        na_mask = self._mask.copy() if copy else self._mask
    else:
        # The float64 conversion already yields fresh values; the mask
        # is copied too so that ``self`` is left untouched.
        values = self._data.astype("f8")
        na_mask = self._mask.copy()

    # Only a float array interpolated without copying hands back ``self``.
    modified_self = is_float and not copy

    # Fills ``values`` and rewrites ``na_mask`` in place.
    missing.interpolate_2d_inplace(
        values,
        method=method,
        axis=0,
        index=index,
        limit=limit,
        limit_direction=limit_direction,
        limit_area=limit_area,
        mask=na_mask,
        **kwargs,
    )

    if modified_self:
        return self  # type: ignore[return-value]
    if is_float:
        return type(self)._simple_new(values, na_mask)  # type: ignore[return-value]

    from pandas.core.arrays import FloatingArray

    return FloatingArray._simple_new(values, na_mask)
1577+
15241578
def _accumulate(
15251579
self, name: str, *, skipna: bool = True, **kwargs
15261580
) -> BaseMaskedArray:

pandas/core/missing.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def interpolate_2d_inplace(
349349
limit_direction: str = "forward",
350350
limit_area: str | None = None,
351351
fill_value: Any | None = None,
352+
mask=None,
352353
**kwargs,
353354
) -> None:
354355
"""
@@ -396,6 +397,7 @@ def func(yvalues: np.ndarray) -> None:
396397
limit_area=limit_area_validated,
397398
fill_value=fill_value,
398399
bounds_error=False,
400+
mask=mask,
399401
**kwargs,
400402
)
401403

@@ -440,6 +442,7 @@ def _interpolate_1d(
440442
fill_value: Any | None = None,
441443
bounds_error: bool = False,
442444
order: int | None = None,
445+
mask=None,
443446
**kwargs,
444447
) -> None:
445448
"""
@@ -453,8 +456,10 @@ def _interpolate_1d(
453456
-----
454457
Fills 'yvalues' in-place.
455458
"""
456-
457-
invalid = isna(yvalues)
459+
if mask is not None:
460+
invalid = mask
461+
else:
462+
invalid = isna(yvalues)
458463
valid = ~invalid
459464

460465
if not valid.any():
@@ -531,7 +536,10 @@ def _interpolate_1d(
531536
**kwargs,
532537
)
533538

534-
if is_datetimelike:
539+
if mask is not None:
540+
mask[:] = False
541+
mask[preserve_nans] = True
542+
elif is_datetimelike:
535543
yvalues[preserve_nans] = NaT.value
536544
else:
537545
yvalues[preserve_nans] = np.nan

pandas/tests/frame/methods/test_interpolate.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,41 @@ def test_interpolate_empty_df(self):
498498
assert result is None
499499
tm.assert_frame_equal(df, expected)
500500

501-
def test_interpolate_ea_raise(self):
501+
def test_interpolate_ea(self, any_int_ea_dtype):
    """
    Masked integer columns interpolate to Float64 without touching the input.

    Supersedes ``test_interpolate_ea_raise``, which expected
    ``NotImplementedError`` for ``Int64`` data.
    """
    # GH#55347
    frame = DataFrame({"a": [1, None, None, None, 3]}, dtype=any_int_ea_dtype)
    snapshot = frame.copy()

    # limit=2 fills only the first two of the three consecutive gaps.
    expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="Float64")
    result = frame.interpolate(limit=2)

    tm.assert_frame_equal(result, expected)
    tm.assert_frame_equal(frame, snapshot)
509+
510+
@pytest.mark.parametrize(
    "dtype",
    [
        "Float64",
        "Float32",
        pytest.param("float32[pyarrow]", marks=td.skip_if_no("pyarrow")),
        pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
    ],
)
def test_interpolate_ea_float(self, dtype):
    """Float extension dtypes keep their dtype and leave the input intact."""
    # GH#55347
    frame = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype)
    snapshot = frame.copy()

    # limit=2 fills only the first two of the three consecutive gaps.
    expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype=dtype)
    result = frame.interpolate(limit=2)

    tm.assert_frame_equal(result, expected)
    tm.assert_frame_equal(frame, snapshot)
527+
528+
@pytest.mark.parametrize(
    "dtype",
    ["int64", "uint64", "int32", "int16", "int8", "uint32", "uint16", "uint8"],
)
def test_interpolate_arrow(self, dtype):
    """Arrow integer columns interpolate to float64[pyarrow] without touching the input."""
    # GH#55347
    pytest.importorskip("pyarrow")
    df = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype + "[pyarrow]")
    orig = df.copy()
    # limit=2 fills only the first two of the three consecutive gaps.
    result = df.interpolate(limit=2)
    expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="float64[pyarrow]")
    tm.assert_frame_equal(result, expected)
    # As in the masked-dtype tests, the input frame must not be modified.
    tm.assert_frame_equal(df, orig)

0 commit comments

Comments
 (0)