Skip to content

Commit 4aba367

Browse files
mroeschkepmhatre1
authored andcommitted
ENH: Implement Series.interpolate for ArrowDtype (pandas-dev#56347)
* ENH: Implement Series.interpolate for ArrowDtype * Min version compat * Fold into interpolate * Remove from 2.2 * Modify tests
1 parent d130a47 commit 4aba367

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

pandas/core/arrays/arrow/array.py

+17
Original file line numberDiff line numberDiff line change
@@ -2086,6 +2086,23 @@ def interpolate(
20862086
See NDFrame.interpolate.__doc__.
20872087
"""
20882088
# NB: we return type(self) even if copy=False
2089+
if not self.dtype._is_numeric:
2090+
raise ValueError("Values must be numeric.")
2091+
2092+
if (
2093+
not pa_version_under13p0
2094+
and method == "linear"
2095+
and limit_area is None
2096+
and limit is None
2097+
and limit_direction == "forward"
2098+
):
2099+
values = self._pa_array.combine_chunks()
2100+
na_value = pa.array([None], type=values.type)
2101+
y_diff_2 = pc.fill_null_backward(pc.pairwise_diff_checked(values, period=2))
2102+
prev_values = pa.concat_arrays([na_value, values[:-2], na_value])
2103+
interps = pc.add_checked(prev_values, pc.divide_checked(y_diff_2, 2))
2104+
return type(self)(pc.coalesce(self._pa_array, interps))
2105+
20892106
mask = self.isna()
20902107
if self.dtype.kind == "f":
20912108
data = self._pa_array.to_numpy()

pandas/tests/extension/test_arrow.py

+20
Original file line numberDiff line numberDiff line change
@@ -3436,6 +3436,26 @@ def test_string_to_datetime_parsing_cast():
34363436
tm.assert_series_equal(result, expected)
34373437

34383438

3439+
@pytest.mark.skipif(
3440+
pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow"
3441+
)
3442+
def test_interpolate_not_numeric(data):
3443+
if not data.dtype._is_numeric:
3444+
with pytest.raises(ValueError, match="Values must be numeric."):
3445+
pd.Series(data).interpolate()
3446+
3447+
3448+
@pytest.mark.skipif(
3449+
pa_version_under13p0, reason="pairwise_diff_checked not implemented in pyarrow"
3450+
)
3451+
@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "float64[pyarrow]"])
3452+
def test_interpolate_linear(dtype):
3453+
ser = pd.Series([None, 1, 2, None, 4, None], dtype=dtype)
3454+
result = ser.interpolate()
3455+
expected = pd.Series([None, 1, 2, 3, 4, None], dtype=dtype)
3456+
tm.assert_series_equal(result, expected)
3457+
3458+
34393459
def test_string_to_time_parsing_cast():
34403460
# GH 56463
34413461
string_times = ["11:41:43.076160"]

0 commit comments

Comments
 (0)