Skip to content

Commit b5d4e89

Browse files
authored
ENH: Implement cum* methods for PyArrow strings (#60633)
* ENH: Implement cum* methods for PyArrow strings * cleanup * Cleanup * fixup * Fix extension tests * xfail test when there is no pyarrow * mypy fixups * Change logic & whatsnew * Change logic & whatsnew * Fix fixture * Fixup
1 parent 1708e90 commit b5d4e89

File tree

8 files changed

+155
-11
lines changed

8 files changed

+155
-11
lines changed

doc/source/whatsnew/v2.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ Other enhancements
3535
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
3636
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
3737
updated to work correctly with NumPy >= 2 (:issue:`57739`)
38+
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
3839
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
39-
-
4040

4141
.. ---------------------------------------------------------------------------
4242
.. _whatsnew_230.notable_bug_fixes:

pandas/conftest.py

+16
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,22 @@ def nullable_string_dtype(request):
13171317
return request.param
13181318

13191319

1320+
@pytest.fixture(
1321+
params=[
1322+
pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
1323+
pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
1324+
]
1325+
)
1326+
def pyarrow_string_dtype(request):
1327+
"""
1328+
Parametrized fixture for string dtypes backed by Pyarrow.
1329+
1330+
* 'str[pyarrow]'
1331+
* 'string[pyarrow]'
1332+
"""
1333+
return pd.StringDtype(*request.param)
1334+
1335+
13201336
@pytest.fixture(
13211337
params=[
13221338
"python",

pandas/core/arrays/arrow/array.py

+55
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
is_list_like,
4242
is_numeric_dtype,
4343
is_scalar,
44+
is_string_dtype,
4445
pandas_dtype,
4546
)
4647
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -1619,6 +1620,9 @@ def _accumulate(
16191620
------
16201621
NotImplementedError : subclass does not define accumulations
16211622
"""
1623+
if is_string_dtype(self):
1624+
return self._str_accumulate(name=name, skipna=skipna, **kwargs)
1625+
16221626
pyarrow_name = {
16231627
"cummax": "cumulative_max",
16241628
"cummin": "cumulative_min",
@@ -1654,6 +1658,57 @@ def _accumulate(
16541658

16551659
return type(self)(result)
16561660

1661+
def _str_accumulate(
1662+
self, name: str, *, skipna: bool = True, **kwargs
1663+
) -> ArrowExtensionArray | ExtensionArray:
1664+
"""
1665+
Accumulate implementation for strings, see `_accumulate` docstring for details.
1666+
1667+
pyarrow.compute does not implement these methods for strings.
1668+
"""
1669+
if name == "cumprod":
1670+
msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
1671+
raise TypeError(msg)
1672+
1673+
# We may need to strip out trailing NA values
1674+
tail: pa.array | None = None
1675+
na_mask: pa.array | None = None
1676+
pa_array = self._pa_array
1677+
np_func = {
1678+
"cumsum": np.cumsum,
1679+
"cummin": np.minimum.accumulate,
1680+
"cummax": np.maximum.accumulate,
1681+
}[name]
1682+
1683+
if self._hasna:
1684+
na_mask = pc.is_null(pa_array)
1685+
if pc.all(na_mask) == pa.scalar(True):
1686+
return type(self)(pa_array)
1687+
if skipna:
1688+
if name == "cumsum":
1689+
pa_array = pc.fill_null(pa_array, "")
1690+
else:
1691+
# We can retain the running min/max by forward/backward filling.
1692+
pa_array = pc.fill_null_forward(pa_array)
1693+
pa_array = pc.fill_null_backward(pa_array)
1694+
else:
1695+
# When not skipping NA values, the result should be null from
1696+
# the first NA value onward.
1697+
idx = pc.index(na_mask, True).as_py()
1698+
tail = pa.nulls(len(pa_array) - idx, type=pa_array.type)
1699+
pa_array = pa_array[:idx]
1700+
1701+
# error: Cannot call function of unknown type
1702+
pa_result = pa.array(np_func(pa_array), type=pa_array.type) # type: ignore[operator]
1703+
1704+
if tail is not None:
1705+
pa_result = pa.concat_arrays([pa_result, tail])
1706+
elif na_mask is not None:
1707+
pa_result = pc.if_else(na_mask, None, pa_result)
1708+
1709+
result = type(self)(pa_result)
1710+
return result
1711+
16571712
def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar:
16581713
"""
16591714
Return a pyarrow scalar result of performing the reduction operation.

pandas/tests/apply/test_str.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import numpy as np
55
import pytest
66

7-
from pandas.compat import WASM
7+
from pandas.compat import (
8+
HAS_PYARROW,
9+
WASM,
10+
)
811

912
from pandas.core.dtypes.common import is_number
1013

@@ -163,10 +166,10 @@ def test_agg_cython_table_transform_series(request, series, func, expected):
163166
# GH21224
164167
# test transforming functions in
165168
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
166-
if series.dtype == "string" and func == "cumsum":
169+
if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW:
167170
request.applymarker(
168171
pytest.mark.xfail(
169-
raises=(TypeError, NotImplementedError),
172+
raises=NotImplementedError,
170173
reason="TODO(infer_string) cumsum not yet implemented for string",
171174
)
172175
)

pandas/tests/extension/base/accumulate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
1818
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
1919
try:
2020
alt = ser.astype("float64")
21-
except TypeError:
22-
# e.g. Period can't be cast to float64
21+
except (TypeError, ValueError):
22+
# e.g. Period can't be cast to float64 (TypeError)
23+
# String can't be cast to float64 (ValueError)
2324
alt = ser.astype(object)
2425

2526
result = getattr(ser, op_name)(skipna=skipna)

pandas/tests/extension/test_arrow.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
393393
# attribute "pyarrow_dtype"
394394
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]
395395

396-
if (
397-
pa.types.is_string(pa_type)
398-
or pa.types.is_binary(pa_type)
399-
or pa.types.is_decimal(pa_type)
400-
):
396+
if pa.types.is_binary(pa_type) or pa.types.is_decimal(pa_type):
401397
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
402398
return False
399+
elif pa.types.is_string(pa_type):
400+
if op_name == "cumprod":
401+
return False
403402
elif pa.types.is_boolean(pa_type):
404403
if op_name in ["cumprod", "cummax", "cummin"]:
405404
return False
@@ -414,6 +413,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
414413
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
415414
pa_type = data.dtype.pyarrow_dtype
416415
op_name = all_numeric_accumulations
416+
417+
if pa.types.is_string(pa_type) and op_name in ["cumsum", "cummin", "cummax"]:
418+
# https://github.com/pandas-dev/pandas/pull/60633
419+
# Doesn't fit test structure, tested in series/test_cumulative.py instead.
420+
return
421+
417422
ser = pd.Series(data)
418423

419424
if not self._supports_accumulation(ser, op_name):

pandas/tests/extension/test_string.py

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
from pandas.compat import HAS_PYARROW
2626

27+
from pandas.core.dtypes.base import StorageExtensionDtype
28+
2729
import pandas as pd
2830
import pandas._testing as tm
2931
from pandas.api.types import is_string_dtype
@@ -192,6 +194,14 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
192194
and op_name in ("any", "all")
193195
)
194196

197+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
198+
assert isinstance(ser.dtype, StorageExtensionDtype)
199+
return ser.dtype.storage == "pyarrow" and op_name in [
200+
"cummin",
201+
"cummax",
202+
"cumsum",
203+
]
204+
195205
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
196206
dtype = cast(StringDtype, tm.get_dtype(obj))
197207
if op_name in ["__add__", "__radd__"]:

pandas/tests/series/test_cumulative.py

+54
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
tests.frame.test_cumulative
77
"""
88

9+
import re
10+
911
import numpy as np
1012
import pytest
1113

@@ -227,3 +229,55 @@ def test_cumprod_timedelta(self):
227229
ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)])
228230
with pytest.raises(TypeError, match="cumprod not supported for Timedelta"):
229231
ser.cumprod()
232+
233+
@pytest.mark.parametrize(
234+
"data, op, skipna, expected_data",
235+
[
236+
([], "cumsum", True, []),
237+
([], "cumsum", False, []),
238+
(["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]),
239+
(["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]),
240+
(["x", pd.NA, "y"], "cumsum", True, ["x", pd.NA, "xy"]),
241+
(["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]),
242+
([pd.NA, "x", "y"], "cumsum", True, [pd.NA, "x", "xy"]),
243+
([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
244+
([pd.NA, pd.NA, pd.NA], "cumsum", True, [pd.NA, pd.NA, pd.NA]),
245+
([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
246+
([], "cummin", True, []),
247+
([], "cummin", False, []),
248+
(["y", "z", "x"], "cummin", True, ["y", "y", "x"]),
249+
(["y", "z", "x"], "cummin", False, ["y", "y", "x"]),
250+
(["y", pd.NA, "x"], "cummin", True, ["y", pd.NA, "x"]),
251+
(["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]),
252+
([pd.NA, "y", "x"], "cummin", True, [pd.NA, "y", "x"]),
253+
([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]),
254+
([pd.NA, pd.NA, pd.NA], "cummin", True, [pd.NA, pd.NA, pd.NA]),
255+
([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]),
256+
([], "cummax", True, []),
257+
([], "cummax", False, []),
258+
(["x", "z", "y"], "cummax", True, ["x", "z", "z"]),
259+
(["x", "z", "y"], "cummax", False, ["x", "z", "z"]),
260+
(["x", pd.NA, "y"], "cummax", True, ["x", pd.NA, "y"]),
261+
(["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]),
262+
([pd.NA, "x", "y"], "cummax", True, [pd.NA, "x", "y"]),
263+
([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]),
264+
([pd.NA, pd.NA, pd.NA], "cummax", True, [pd.NA, pd.NA, pd.NA]),
265+
([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
266+
],
267+
)
268+
def test_cum_methods_pyarrow_strings(
269+
self, pyarrow_string_dtype, data, op, skipna, expected_data
270+
):
271+
# https://github.com/pandas-dev/pandas/pull/60633
272+
ser = pd.Series(data, dtype=pyarrow_string_dtype)
273+
method = getattr(ser, op)
274+
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
275+
result = method(skipna=skipna)
276+
tm.assert_series_equal(result, expected)
277+
278+
def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna):
279+
# https://github.com/pandas-dev/pandas/pull/60633
280+
ser = pd.Series(list("xyz"), dtype=pyarrow_string_dtype)
281+
msg = re.escape(f"operation 'cumprod' not supported for dtype '{ser.dtype}'")
282+
with pytest.raises(TypeError, match=msg):
283+
ser.cumprod(skipna=skipna)

0 commit comments

Comments
 (0)