Skip to content

Commit bd8d56e

Browse files
rhshadrach authored and jorisvandenbossche committed
ENH: Implement cum* methods for PyArrow strings (pandas-dev#60633)
* ENH: Implement cum* methods for PyArrow strings
* cleanup
* Cleanup
* fixup
* Fix extension tests
* xfail test when there is no pyarrow
* mypy fixups
* Change logic & whatsnew
* Change logic & whatsnew
* Fix fixture
* Fixup

(cherry picked from commit b5d4e89)
1 parent 36d34a1 commit bd8d56e

File tree

8 files changed

+157
-10
lines changed

8 files changed

+157
-10
lines changed

doc/source/whatsnew/v2.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ Other enhancements
3535
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
3636
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
3737
updated to raise FutureWarning with NumPy >= 2 (:issue:`60340`)
38+
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` cumulative methods are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
3839
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
39-
-
4040

4141
.. ---------------------------------------------------------------------------
4242
.. _whatsnew_230.notable_bug_fixes:

pandas/conftest.py

+16
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,22 @@ def nullable_string_dtype(request):
12731273
return request.param
12741274

12751275

1276+
@pytest.fixture(
    params=[
        pytest.param(("pyarrow", na_value), marks=td.skip_if_no("pyarrow"))
        for na_value in (np.nan, pd.NA)
    ]
)
def pyarrow_string_dtype(request):
    """
    Parametrized fixture for string dtypes backed by PyArrow.

    Each parameter is a ``(storage, na_value)`` pair that is passed
    positionally to ``pd.StringDtype``; both are skipped when pyarrow is
    not installed.

    * 'str[pyarrow]' (``na_value=np.nan``)
    * 'string[pyarrow]' (``na_value=pd.NA``)
    """
    storage, na_value = request.param
    return pd.StringDtype(storage, na_value)
1290+
1291+
12761292
@pytest.fixture(
12771293
params=[
12781294
"python",

pandas/core/arrays/arrow/array.py

+55
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
is_list_like,
4646
is_numeric_dtype,
4747
is_scalar,
48+
is_string_dtype,
4849
)
4950
from pandas.core.dtypes.dtypes import DatetimeTZDtype
5051
from pandas.core.dtypes.missing import isna
@@ -1617,6 +1618,9 @@ def _accumulate(
16171618
------
16181619
NotImplementedError : subclass does not define accumulations
16191620
"""
1621+
if is_string_dtype(self):
1622+
return self._str_accumulate(name=name, skipna=skipna, **kwargs)
1623+
16201624
pyarrow_name = {
16211625
"cummax": "cumulative_max",
16221626
"cummin": "cumulative_min",
@@ -1652,6 +1656,57 @@ def _accumulate(
16521656

16531657
return type(self)(result)
16541658

1659+
def _str_accumulate(
    self, name: str, *, skipna: bool = True, **kwargs
) -> ArrowExtensionArray | ExtensionArray:
    """
    Accumulate implementation for strings, see `_accumulate` docstring for details.

    pyarrow.compute does not implement these methods for strings, so the
    accumulation is computed with NumPy and the result is converted back to
    a pyarrow array of the original type.

    Parameters
    ----------
    name : str
        Name of the accumulation: "cumsum", "cummin", "cummax" or "cumprod".
    skipna : bool, default True
        If True, null entries are ignored by the running result and stay
        null at their own positions. If False, the result is null from the
        first null entry onward.
    **kwargs
        Accepted for signature compatibility; not used by this method.

    Returns
    -------
    ArrowExtensionArray or ExtensionArray
        A new array built with ``type(self)`` holding the accumulated values.

    Raises
    ------
    TypeError
        If ``name`` is "cumprod", which is not supported for strings.
    KeyError
        If ``name`` is not one of the names listed above.
    """
    if name == "cumprod":
        msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
        raise TypeError(msg)

    # We may need to strip out trailing NA values
    tail: pa.Array | None = None
    na_mask: pa.Array | pa.ChunkedArray | None = None
    pa_array = self._pa_array
    np_func = {
        "cumsum": np.cumsum,
        "cummin": np.minimum.accumulate,
        "cummax": np.maximum.accumulate,
    }[name]

    if self._hasna:
        na_mask = pc.is_null(pa_array)
        if pc.all(na_mask).as_py():
            # Every entry is null, so there is nothing to accumulate.
            return type(self)(pa_array)
        if skipna:
            if name == "cumsum":
                # Concatenating the empty string leaves the running sum
                # unchanged; the nulls are restored from na_mask below.
                pa_array = pc.fill_null(pa_array, "")
            else:
                # We can retain the running min/max by forward/backward filling.
                # The backward fill covers leading nulls, which the forward
                # fill cannot reach.
                pa_array = pc.fill_null_forward(pa_array)
                pa_array = pc.fill_null_backward(pa_array)
        else:
            # When not skipping NA values, the result should be null from
            # the first NA value onward.
            idx = pc.index(na_mask, True).as_py()
            tail = pa.nulls(len(pa_array) - idx, type=pa_array.type)
            pa_array = pa_array[:idx]

    # error: Cannot call function of unknown type
    pa_result = pa.array(np_func(pa_array), type=pa_array.type)  # type: ignore[operator]

    if tail is not None:
        # skipna=False: append the all-null remainder after the first NA.
        pa_result = pa.concat_arrays([pa_result, tail])
    elif na_mask is not None:
        # skipna=True: put nulls back at the positions that were filled.
        pa_result = pc.if_else(na_mask, None, pa_result)

    return type(self)(pa_result)
1709+
16551710
def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar:
16561711
"""
16571712
Return a pyarrow scalar result of performing the reduction operation.

pandas/tests/apply/test_str.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
import pytest
66

7+
from pandas.compat import HAS_PYARROW
8+
79
from pandas.core.dtypes.common import is_number
810

911
from pandas import (
@@ -170,10 +172,14 @@ def test_agg_cython_table_transform_series(request, series, func, expected):
170172
# GH21224
171173
# test transforming functions in
172174
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
173-
if series.dtype == "string" and func in ("cumsum", np.cumsum, np.nancumsum):
175+
if (
176+
series.dtype == "string"
177+
and func in ("cumsum", np.cumsum, np.nancumsum)
178+
and not HAS_PYARROW
179+
):
174180
request.applymarker(
175181
pytest.mark.xfail(
176-
raises=(TypeError, NotImplementedError),
182+
raises=NotImplementedError,
177183
reason="TODO(infer_string) cumsum not yet implemented for string",
178184
)
179185
)

pandas/tests/extension/base/accumulate.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
1818
def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
1919
try:
2020
alt = ser.astype("float64")
21-
except TypeError:
22-
# e.g. Period can't be cast to float64
21+
except (TypeError, ValueError):
22+
# e.g. Period can't be cast to float64 (TypeError)
23+
# String can't be cast to float64 (ValueError)
2324
alt = ser.astype(object)
2425

2526
result = getattr(ser, op_name)(skipna=skipna)

pandas/tests/extension/test_arrow.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -388,13 +388,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
388388
# attribute "pyarrow_dtype"
389389
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]
390390

391-
if (
392-
pa.types.is_string(pa_type)
393-
or pa.types.is_binary(pa_type)
394-
or pa.types.is_decimal(pa_type)
395-
):
391+
if pa.types.is_binary(pa_type) or pa.types.is_decimal(pa_type):
396392
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
397393
return False
394+
elif pa.types.is_string(pa_type):
395+
if op_name == "cumprod":
396+
return False
398397
elif pa.types.is_boolean(pa_type):
399398
if op_name in ["cumprod", "cummax", "cummin"]:
400399
return False
@@ -409,6 +408,12 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
409408
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
410409
pa_type = data.dtype.pyarrow_dtype
411410
op_name = all_numeric_accumulations
411+
412+
if pa.types.is_string(pa_type) and op_name in ["cumsum", "cummin", "cummax"]:
413+
# https://github.com/pandas-dev/pandas/pull/60633
414+
# Doesn't fit test structure, tested in series/test_cumulative.py instead.
415+
return
416+
412417
ser = pd.Series(data)
413418

414419
if not self._supports_accumulation(ser, op_name):

pandas/tests/extension/test_string.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
from pandas.compat import HAS_PYARROW
2525

26+
from pandas.core.dtypes.base import StorageExtensionDtype
27+
2628
import pandas as pd
2729
import pandas._testing as tm
2830
from pandas.api.types import is_string_dtype
@@ -196,6 +198,14 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
196198
and op_name in ("any", "all")
197199
)
198200

201+
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
    """
    Return whether ``op_name`` is expected to work on ``ser``.

    Only pyarrow-backed storage supports accumulations, and only for
    "cummin", "cummax" and "cumsum".
    """
    dtype = ser.dtype
    assert isinstance(dtype, StorageExtensionDtype)
    if dtype.storage != "pyarrow":
        return False
    return op_name in ("cummin", "cummax", "cumsum")
208+
199209
def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
200210
dtype = cast(StringDtype, tm.get_dtype(obj))
201211
if op_name in ["__add__", "__radd__"]:

pandas/tests/series/test_cumulative.py

+54
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
tests.frame.test_cumulative
77
"""
88

9+
import re
10+
911
import numpy as np
1012
import pytest
1113

@@ -155,3 +157,55 @@ def test_cumprod_timedelta(self):
155157
ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)])
156158
with pytest.raises(TypeError, match="cumprod not supported for Timedelta"):
157159
ser.cumprod()
160+
161+
@pytest.mark.parametrize(
    "data, op, skipna, expected_data",
    [
        # cumsum: running string concatenation
        ([], "cumsum", True, []),
        ([], "cumsum", False, []),
        (["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]),
        (["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]),
        (["x", pd.NA, "y"], "cumsum", True, ["x", pd.NA, "xy"]),
        (["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]),
        ([pd.NA, "x", "y"], "cumsum", True, [pd.NA, "x", "xy"]),
        ([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cumsum", True, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]),
        # cummin: running minimum
        ([], "cummin", True, []),
        ([], "cummin", False, []),
        (["y", "z", "x"], "cummin", True, ["y", "y", "x"]),
        (["y", "z", "x"], "cummin", False, ["y", "y", "x"]),
        (["y", pd.NA, "x"], "cummin", True, ["y", pd.NA, "x"]),
        (["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]),
        ([pd.NA, "y", "x"], "cummin", True, [pd.NA, "y", "x"]),
        ([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cummin", True, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]),
        # cummax: running maximum
        ([], "cummax", True, []),
        ([], "cummax", False, []),
        (["x", "z", "y"], "cummax", True, ["x", "z", "z"]),
        (["x", "z", "y"], "cummax", False, ["x", "z", "z"]),
        (["x", pd.NA, "y"], "cummax", True, ["x", pd.NA, "y"]),
        (["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]),
        ([pd.NA, "x", "y"], "cummax", True, [pd.NA, "x", "y"]),
        ([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cummax", True, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
    ],
)
def test_cum_methods_pyarrow_strings(
    self, pyarrow_string_dtype, data, op, skipna, expected_data
):
    # https://github.com/pandas-dev/pandas/pull/60633
    # Each case covers empty input, no NA, an interior NA, a leading NA and
    # all-NA, for both skipna=True and skipna=False.
    dtype = pyarrow_string_dtype
    result = getattr(pd.Series(data, dtype=dtype), op)(skipna=skipna)
    expected = pd.Series(expected_data, dtype=dtype)
    tm.assert_series_equal(result, expected)
205+
206+
def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna):
    # https://github.com/pandas-dev/pandas/pull/60633
    # cumprod must raise TypeError for pyarrow-backed strings, whatever
    # the value of skipna.
    ser = pd.Series(["x", "y", "z"], dtype=pyarrow_string_dtype)
    expected_msg = f"operation 'cumprod' not supported for dtype '{ser.dtype}'"
    with pytest.raises(TypeError, match=re.escape(expected_msg)):
        ser.cumprod(skipna=skipna)

0 commit comments

Comments
 (0)