Skip to content

Commit 2e367c5

Browse files
lukemanleymliu08
authored andcommitted
PERF: Series.fillna for pyarrow-backed dtypes (pandas-dev#49722)
* ArrowExtensionArray.fillna perf * whatsnew * fixes * tighter try/excepts * cleanup * test for performance warning * test for performance warning
1 parent e32b504 commit 2e367c5

File tree

6 files changed

+142
-10
lines changed

6 files changed

+142
-10
lines changed

asv_bench/benchmarks/series_methods.py

+42
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,48 @@ def time_dropna(self, dtype):
7979
self.s.dropna()
8080

8181

82+
class Fillna:
83+
84+
params = [
85+
[
86+
"datetime64[ns]",
87+
"float64",
88+
"Int64",
89+
"int64[pyarrow]",
90+
"string",
91+
"string[pyarrow]",
92+
],
93+
[None, "pad", "backfill"],
94+
]
95+
param_names = ["dtype", "method"]
96+
97+
def setup(self, dtype, method):
98+
N = 10**6
99+
if dtype == "datetime64[ns]":
100+
data = date_range("2000-01-01", freq="S", periods=N)
101+
na_value = NaT
102+
elif dtype == "float64":
103+
data = np.random.randn(N)
104+
na_value = np.nan
105+
elif dtype in ("Int64", "int64[pyarrow]"):
106+
data = np.arange(N)
107+
na_value = NA
108+
elif dtype in ("string", "string[pyarrow]"):
109+
data = tm.rands_array(5, N)
110+
na_value = NA
111+
else:
112+
raise NotImplementedError
113+
fill_value = data[0]
114+
ser = Series(data, dtype=dtype)
115+
ser[::2] = na_value
116+
self.ser = ser
117+
self.fill_value = fill_value
118+
119+
def time_fillna(self, dtype, method):
120+
value = self.fill_value if method is None else None
121+
self.ser.fillna(value=value, method=method)
122+
123+
82124
class SearchSorted:
83125

84126
goal_time = 0.2

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ Performance improvements
586586
- Performance improvement in :meth:`.DataFrameGroupBy.mean`, :meth:`.SeriesGroupBy.mean`, :meth:`.DataFrameGroupBy.var`, and :meth:`.SeriesGroupBy.var` for extension array dtypes (:issue:`37493`)
587587
- Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`, :issue:`49577`)
588588
- Performance improvement in :meth:`Index.union` and :meth:`MultiIndex.union` when index contains duplicates (:issue:`48900`)
589+
- Performance improvement in :meth:`Series.fillna` for pyarrow-backed dtypes (:issue:`49722`)
589590
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
590591
- Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`)
591592
- Performance improvement for :class:`DatetimeIndex` constructor passing a list (:issue:`48609`)

pandas/core/arrays/arrow/array.py

+64
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
TYPE_CHECKING,
55
Any,
66
TypeVar,
7+
cast,
78
)
89

910
import numpy as np
1011

1112
from pandas._typing import (
13+
ArrayLike,
1214
Dtype,
15+
FillnaOptions,
1316
PositionalIndexer,
1417
SortKind,
1518
TakeIndexer,
@@ -20,6 +23,7 @@
2023
pa_version_under7p0,
2124
)
2225
from pandas.util._decorators import doc
26+
from pandas.util._validators import validate_fillna_kwargs
2327

2428
from pandas.core.dtypes.common import (
2529
is_array_like,
@@ -521,6 +525,66 @@ def dropna(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT:
521525
else:
522526
return type(self)(pc.drop_null(self._data))
523527

528+
@doc(ExtensionArray.fillna)
529+
def fillna(
530+
self: ArrowExtensionArrayT,
531+
value: object | ArrayLike | None = None,
532+
method: FillnaOptions | None = None,
533+
limit: int | None = None,
534+
) -> ArrowExtensionArrayT:
535+
536+
value, method = validate_fillna_kwargs(value, method)
537+
538+
if limit is not None:
539+
return super().fillna(value=value, method=method, limit=limit)
540+
541+
if method is not None and pa_version_under7p0:
542+
# fill_null_{forward|backward} added in pyarrow 7.0
543+
fallback_performancewarning(version="7")
544+
return super().fillna(value=value, method=method, limit=limit)
545+
546+
if is_array_like(value):
547+
value = cast(ArrayLike, value)
548+
if len(value) != len(self):
549+
raise ValueError(
550+
f"Length of 'value' does not match. Got ({len(value)}) "
551+
f" expected {len(self)}"
552+
)
553+
554+
def convert_fill_value(value, pa_type, dtype):
555+
if value is None:
556+
return value
557+
if isinstance(value, (pa.Scalar, pa.Array, pa.ChunkedArray)):
558+
return value
559+
if is_array_like(value):
560+
pa_box = pa.array
561+
else:
562+
pa_box = pa.scalar
563+
try:
564+
value = pa_box(value, type=pa_type, from_pandas=True)
565+
except pa.ArrowTypeError as err:
566+
msg = f"Invalid value '{str(value)}' for dtype {dtype}"
567+
raise TypeError(msg) from err
568+
return value
569+
570+
fill_value = convert_fill_value(value, self._data.type, self.dtype)
571+
572+
try:
573+
if method is None:
574+
return type(self)(pc.fill_null(self._data, fill_value=fill_value))
575+
elif method == "pad":
576+
return type(self)(pc.fill_null_forward(self._data))
577+
elif method == "backfill":
578+
return type(self)(pc.fill_null_backward(self._data))
579+
except pa.ArrowNotImplementedError:
580+
# ArrowNotImplementedError: Function 'coalesce' has no kernel
581+
# matching input types (duration[ns], duration[ns])
582+
# TODO: remove try/except wrapper if/when pyarrow implements
583+
# a kernel for duration types.
584+
pass
585+
586+
return super().fillna(value=value, method=method, limit=limit)
587+
524588
def isin(self, values) -> npt.NDArray[np.bool_]:
525589
# short-circuit to return all False array.
526590
if not len(values):

pandas/tests/arrays/string_/test_string.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -408,14 +408,6 @@ def test_min_max_numpy(method, box, dtype, request):
408408
def test_fillna_args(dtype, request):
409409
# GH 37987
410410

411-
if dtype.storage == "pyarrow":
412-
reason = (
413-
"Regex pattern \"Cannot set non-string value '1' into "
414-
"a StringArray.\" does not match 'Scalar must be NA or str'"
415-
)
416-
mark = pytest.mark.xfail(raises=AssertionError, reason=reason)
417-
request.node.add_marker(mark)
418-
419411
arr = pd.array(["a", pd.NA], dtype=dtype)
420412

421413
res = arr.fillna(value="b")
@@ -426,8 +418,13 @@ def test_fillna_args(dtype, request):
426418
expected = pd.array(["a", "b"], dtype=dtype)
427419
tm.assert_extension_array_equal(res, expected)
428420

429-
msg = "Cannot set non-string value '1' into a StringArray."
430-
with pytest.raises(ValueError, match=msg):
421+
if dtype.storage == "pyarrow":
422+
err = TypeError
423+
msg = "Invalid value '1' for dtype string"
424+
else:
425+
err = ValueError
426+
msg = "Cannot set non-string value '1' into a StringArray."
427+
with pytest.raises(err, match=msg):
431428
arr.fillna(value=1)
432429

433430

pandas/tests/extension/test_arrow.py

+12
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,18 @@ class TestBaseMissing(base.BaseMissingTests):
632632
def test_dropna_array(self, data_missing):
633633
super().test_dropna_array(data_missing)
634634

635+
def test_fillna_no_op_returns_copy(self, data):
636+
with tm.maybe_produces_warning(
637+
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
638+
):
639+
super().test_fillna_no_op_returns_copy(data)
640+
641+
def test_fillna_series_method(self, data_missing, fillna_method):
642+
with tm.maybe_produces_warning(
643+
PerformanceWarning, pa_version_under7p0, check_stacklevel=False
644+
):
645+
super().test_fillna_series_method(data_missing, fillna_method)
646+
635647

636648
class TestBasePrinting(base.BasePrintingTests):
637649
pass

pandas/tests/extension/test_string.py

+16
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,22 @@ def test_dropna_array(self, data_missing):
168168
expected = data_missing[[1]]
169169
self.assert_extension_array_equal(result, expected)
170170

171+
def test_fillna_no_op_returns_copy(self, data):
172+
with tm.maybe_produces_warning(
173+
PerformanceWarning,
174+
pa_version_under7p0 and data.dtype.storage == "pyarrow",
175+
check_stacklevel=False,
176+
):
177+
super().test_fillna_no_op_returns_copy(data)
178+
179+
def test_fillna_series_method(self, data_missing, fillna_method):
180+
with tm.maybe_produces_warning(
181+
PerformanceWarning,
182+
pa_version_under7p0 and data_missing.dtype.storage == "pyarrow",
183+
check_stacklevel=False,
184+
):
185+
super().test_fillna_series_method(data_missing, fillna_method)
186+
171187

172188
class TestNoReduce(base.BaseNoReduceTests):
173189
@pytest.mark.parametrize("skipna", [True, False])

0 commit comments

Comments
 (0)