Skip to content

ENH: Implement interpolation for arrow and masked dtypes #56757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ Other enhancements
- :meth:`ExtensionArray.duplicated` added to allow extension type implementations of the ``duplicated`` method (:issue:`55255`)
- :meth:`Series.ffill`, :meth:`Series.bfill`, :meth:`DataFrame.ffill`, and :meth:`DataFrame.bfill` have gained the argument ``limit_area``; 3rd party :class:`.ExtensionArray` authors need to add this argument to the method ``_pad_or_backfill`` (:issue:`56492`)
- Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`)
- Implement :meth:`Series.interpolate` and :meth:`DataFrame.interpolate` for :class:`ArrowDtype` and masked dtypes (:issue:`56267`)
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
- Implemented :meth:`Series.dt` methods and attributes for :class:`ArrowDtype` with ``pyarrow.duration`` type (:issue:`52284`)
- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`)
Expand Down
40 changes: 40 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def floordiv_compat(
AxisInt,
Dtype,
FillnaOptions,
InterpolateOptions,
Iterator,
NpDtype,
NumpySorter,
Expand Down Expand Up @@ -2068,6 +2069,45 @@ def _maybe_convert_setitem_value(self, value):
raise TypeError(msg) from err
return value

def interpolate(
    self,
    *,
    method: InterpolateOptions,
    axis: int,
    index,
    limit,
    limit_direction,
    limit_area,
    copy: bool,
    **kwargs,
) -> Self:
    """
    Fill missing values by interpolation.

    See NDFrame.interpolate.__doc__ for the meaning of the arguments.

    Only float and signed/unsigned integer dtypes are handled; integer
    data is converted to float64 before interpolating. ``axis`` and
    ``copy`` are accepted for signature compatibility but are not read
    here: the data is always interpolated along axis 0 and a new array
    is always returned.

    Raises
    ------
    NotImplementedError
        If the dtype kind is neither float nor integer.
    """
    na_mask = self.isna()
    kind = self.dtype.kind
    if kind == "f":
        values = self._pa_array.to_numpy()
    elif kind in "iu":
        # The placeholder written into the null slots is irrelevant:
        # ``na_mask`` is handed to the interpolation routine below and is
        # what decides which positions count as missing.
        values = self.to_numpy(dtype="f8", na_value=0.0)
    else:
        raise NotImplementedError(
            f"interpolate is not implemented for dtype={self.dtype}"
        )

    # Fills ``values`` in place; ``na_mask`` is passed along so the helper
    # can use it instead of inspecting the values for NaN.
    missing.interpolate_2d_inplace(
        values,
        method=method,
        axis=0,
        index=index,
        limit=limit,
        limit_direction=limit_direction,
        limit_area=limit_area,
        mask=na_mask,
        **kwargs,
    )

    # Rebuild an arrow array, marking as null whatever the mask still flags.
    filled = pa.array(values, mask=na_mask)
    return type(self)(self._box_pa_array(filled))

@classmethod
def _if_else(
cls,
Expand Down
54 changes: 54 additions & 0 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AxisInt,
DtypeObj,
FillnaOptions,
InterpolateOptions,
NpDtype,
PositionalIndexer,
Scalar,
Expand Down Expand Up @@ -99,6 +100,7 @@
NumpyValueArrayLike,
)
from pandas._libs.missing import NAType
from pandas.core.arrays import FloatingArray

from pandas.compat.numpy import function as nv

Expand Down Expand Up @@ -1519,6 +1521,58 @@ def all(
else:
return self.dtype.na_value

def interpolate(
    self,
    *,
    method: InterpolateOptions,
    axis: int,
    index,
    limit,
    limit_direction,
    limit_area,
    copy: bool,
    **kwargs,
) -> FloatingArray:
    """
    Fill masked values by interpolation.

    See NDFrame.interpolate.__doc__ for the meaning of the arguments.

    Float data keeps its own array type; with ``copy=False`` the
    underlying data and mask are modified in place and ``self`` is
    returned. Integer data is always converted to float64 on a copy and
    returned as a ``FloatingArray``. ``axis`` is accepted for signature
    compatibility but not read here; interpolation runs along axis 0.

    Raises
    ------
    NotImplementedError
        If the dtype kind is neither float nor integer.
    """
    kind = self.dtype.kind
    is_float = kind == "f"
    if not is_float and kind not in "iu":
        raise NotImplementedError(
            f"interpolate is not implemented for dtype={self.dtype}"
        )

    if is_float:
        values, na_mask = self._data, self._mask
        if copy:
            values, na_mask = values.copy(), na_mask.copy()
    else:
        # Integers cannot hold interpolated values, so the float64
        # conversion forces a copy regardless of what the caller asked for.
        copy = True
        values = self._data.astype("f8")
        na_mask = self._mask.copy()

    # Operates in place on ``values``; ``na_mask`` is passed along so the
    # helper can use it to decide which positions are missing.
    missing.interpolate_2d_inplace(
        values,
        method=method,
        axis=0,
        index=index,
        limit=limit,
        limit_direction=limit_direction,
        limit_area=limit_area,
        mask=na_mask,
        **kwargs,
    )

    if not copy:
        # Only reachable for float data: the buffers that were filled are
        # self's own, so self already carries the result.
        return self  # type: ignore[return-value]
    if is_float:
        return type(self)._simple_new(values, na_mask)  # type: ignore[return-value]

    from pandas.core.arrays import FloatingArray

    return FloatingArray._simple_new(values, na_mask)

def _accumulate(
self, name: str, *, skipna: bool = True, **kwargs
) -> BaseMaskedArray:
Expand Down
14 changes: 11 additions & 3 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def interpolate_2d_inplace(
limit_direction: str = "forward",
limit_area: str | None = None,
fill_value: Any | None = None,
mask=None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -382,6 +383,7 @@ def func(yvalues: np.ndarray) -> None:
limit_area=limit_area_validated,
fill_value=fill_value,
bounds_error=False,
mask=mask,
**kwargs,
)

Expand Down Expand Up @@ -426,6 +428,7 @@ def _interpolate_1d(
fill_value: Any | None = None,
bounds_error: bool = False,
order: int | None = None,
mask=None,
**kwargs,
) -> None:
"""
Expand All @@ -439,8 +442,10 @@ def _interpolate_1d(
-----
Fills 'yvalues' in-place.
"""

invalid = isna(yvalues)
if mask is not None:
invalid = mask
else:
invalid = isna(yvalues)
valid = ~invalid

if not valid.any():
Expand Down Expand Up @@ -517,7 +522,10 @@ def _interpolate_1d(
**kwargs,
)

if is_datetimelike:
if mask is not None:
mask[:] = False
mask[preserve_nans] = True
elif is_datetimelike:
yvalues[preserve_nans] = NaT.value
else:
yvalues[preserve_nans] = np.nan
Expand Down
41 changes: 37 additions & 4 deletions pandas/tests/frame/methods/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,41 @@ def test_interpolate_empty_df(self):
assert result is None
tm.assert_frame_equal(df, expected)

def test_interpolate_ea_raise(self):
def test_interpolate_ea(self, any_int_ea_dtype):
# GH#55347
df = DataFrame({"a": [1, None, 2]}, dtype="Int64")
with pytest.raises(NotImplementedError, match="does not implement"):
df.interpolate()
df = DataFrame({"a": [1, None, None, None, 3]}, dtype=any_int_ea_dtype)
orig = df.copy()
result = df.interpolate(limit=2)
expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="Float64")
tm.assert_frame_equal(result, expected)
tm.assert_frame_equal(df, orig)

@pytest.mark.parametrize(
    "dtype",
    [
        "Float64",
        "Float32",
        pytest.param("float32[pyarrow]", marks=td.skip_if_no("pyarrow")),
        pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
    ],
)
def test_interpolate_ea_float(self, dtype):
    # GH#55347
    # Float extension dtypes should interpolate without changing dtype and
    # without mutating the frame that interpolate() was called on.
    frame = DataFrame({"a": [1, None, None, None, 3]}, dtype=dtype)
    snapshot = frame.copy()

    interpolated = frame.interpolate(limit=2)

    # With limit=2 the third consecutive missing value stays missing.
    expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype=dtype)
    tm.assert_frame_equal(interpolated, expected)
    tm.assert_frame_equal(frame, snapshot)

@pytest.mark.parametrize(
    "dtype",
    # NOTE(review): an explicit list is kept instead of the
    # ``any_int_numpy_dtype`` fixture because, per the review discussion,
    # that fixture also passes plain ``int``, which does not fit the
    # string-built "<dtype>[pyarrow]" name used below.
    ["int64", "uint64", "int32", "int16", "int8", "uint32", "uint16", "uint8"],
)
def test_interpolate_arrow(self, dtype):
    # GH#55347
    pytest.importorskip("pyarrow")
    arrow_dtype = f"{dtype}[pyarrow]"
    frame = DataFrame({"a": [1, None, None, None, 3]}, dtype=arrow_dtype)

    interpolated = frame.interpolate(limit=2)

    # Every integer arrow dtype is expected to come back as float64[pyarrow];
    # with limit=2 the third consecutive missing value stays missing.
    expected = DataFrame({"a": [1, 1.5, 2.0, None, 3]}, dtype="float64[pyarrow]")
    tm.assert_frame_equal(interpolated, expected)