Skip to content

Commit e837689

Browse files
rhshadrach authored and WillAyd committed
ENH: Implement cum* methods for PyArrow strings
1 parent 315b549 commit e837689

File tree

4 files changed

+140
-8
lines changed

4 files changed

+140
-8
lines changed

pandas/conftest.py

+20
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,26 @@ def nullable_string_dtype(request):
13171317
return request.param
13181318

13191319

1320+
@pytest.fixture(
    params=[
        pytest.param(("pyarrow", np.nan), marks=td.skip_if_no("pyarrow")),
        pytest.param(("pyarrow", pd.NA), marks=td.skip_if_no("pyarrow")),
    ]
)
def pyarrow_string_dtype(request):
    """
    Parametrized fixture for string dtypes backed by Pyarrow.

    * 'pd.StringDtype("pyarrow", na_value=np.nan)'
    * 'pd.StringDtype("pyarrow", na_value=pd.NA)'
    """
    # The params are plain (storage, na_value) tuples and the dtype is only
    # built here, inside the fixture. Building a pyarrow-backed StringDtype
    # directly in the ``params`` list would run when this module is imported,
    # i.e. before the ``skip_if_no("pyarrow")`` marks get a chance to skip the
    # test on environments without pyarrow.
    storage, na_value = request.param
    return pd.StringDtype(storage, na_value=na_value)
1338+
1339+
13201340
@pytest.fixture(
13211341
params=[
13221342
"python",

pandas/core/arrays/arrow/array.py

+56
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
is_list_like,
4242
is_numeric_dtype,
4343
is_scalar,
44+
is_string_dtype,
4445
pandas_dtype,
4546
)
4647
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -1619,6 +1620,9 @@ def _accumulate(
16191620
------
16201621
NotImplementedError : subclass does not define accumulations
16211622
"""
1623+
if is_string_dtype(self):
1624+
return self._str_accumulate(name=name, skipna=skipna, **kwargs)
1625+
16221626
pyarrow_name = {
16231627
"cummax": "cumulative_max",
16241628
"cummin": "cumulative_min",
@@ -1654,6 +1658,58 @@ def _accumulate(
16541658

16551659
return type(self)(result)
16561660

1661+
def _str_accumulate(
    self, name: str, *, skipna: bool = True, **kwargs
) -> ArrowExtensionArray | ExtensionArray:
    """
    Accumulate implementation for strings, see `_accumulate` docstring for details.

    pyarrow.compute does not implement these methods for strings, so the
    accumulation itself is performed by NumPy and the result is converted
    back to a pyarrow array of the original type.

    Parameters
    ----------
    name : str
        The accumulation to perform: "cumsum", "cummin" or "cummax".
    skipna : bool, default True
        If True, NA values do not affect the running result:

        * "cumsum" treats NA as the empty string.
        * "cummin" / "cummax" carry the running value over NA positions;
          NA values that precede the first valid value produce the empty
          string in the result.

        If False, the result is NA from the first NA value onward.
    **kwargs
        Accepted for signature compatibility with `_accumulate`; unused.

    Returns
    -------
    ArrowExtensionArray or ExtensionArray
        A new array of the same type as ``self`` holding the accumulation.

    Raises
    ------
    TypeError
        If ``name`` is "cumprod", which is not defined for strings.
    KeyError
        If ``name`` is not one of the accumulations listed above.
    """
    if name == "cumprod":
        msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
        raise TypeError(msg)

    np_func = {
        "cumsum": np.cumsum,
        "cummin": np.minimum.accumulate,
        "cummax": np.maximum.accumulate,
    }[name]

    # Pieces that are not fed through NumPy and are re-attached afterwards:
    #   head - placeholder values for leading NA values (skipna=True, min/max)
    #   tail - the all-NA remainder after the first NA value (skipna=False)
    head: pa.Array | None = None
    tail: pa.Array | None = None
    pa_array = self._pa_array

    if self._hasna:
        if not skipna:
            # Everything from the first NA value onward is NA, so only the
            # part before it needs to be accumulated.
            idx = pc.index(pc.is_null(pa_array), True).as_py()
            tail = pa.nulls(len(pa_array) - idx, type=pa_array.type)
            pa_array = pa_array[:idx].combine_chunks()
        elif name == "cumsum":
            # Concatenating the empty string leaves the running sum as is.
            # No nulls remain afterwards, so there is no head to split off.
            pa_array = pc.fill_null(pa_array, "")
        else:
            # Repeating the previous value leaves a running min/max as is.
            # Leading NA values have no previous value and stay null after
            # the forward fill; they are replaced by empty strings.
            pa_array = pc.fill_null_forward(pa_array)
            idx = pc.index(pc.is_null(pa_array), False).as_py()
            if idx == -1:
                # No valid value at all: the whole array is leading NA.
                idx = len(pa_array)
            if idx > 0:
                head = pa.array([""] * idx, type=pa_array.type)
                pa_array = pa_array[idx:].combine_chunks()

    pa_result = pa.array(np_func(pa_array), type=pa_array.type)

    if head is not None or tail is not None:
        pieces = [
            piece for piece in (head, pa_result, tail) if piece is not None
        ]
        pa_result = pa.concat_arrays(pieces)

    return type(self)(pa_result)
1712+
16571713
def _reduce_pyarrow(self, name: str, *, skipna: bool = True, **kwargs) -> pa.Scalar:
16581714
"""
16591715
Return a pyarrow scalar result of performing the reduction operation.

pandas/tests/apply/test_str.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -159,17 +159,10 @@ def test_agg_cython_table_series(series, func, expected):
159159
),
160160
),
161161
)
162-
def test_agg_cython_table_transform_series(request, series, func, expected):
162+
def test_agg_cython_table_transform_series(series, func, expected):
163163
# GH21224
164164
# test transforming functions in
165165
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
166-
if series.dtype == "string" and func == "cumsum":
167-
request.applymarker(
168-
pytest.mark.xfail(
169-
raises=(TypeError, NotImplementedError),
170-
reason="TODO(infer_string) cumsum not yet implemented for string",
171-
)
172-
)
173166
warn = None if isinstance(func, str) else FutureWarning
174167
with tm.assert_produces_warning(warn, match="is currently using Series.*"):
175168
result = series.agg(func)

pandas/tests/series/test_cumulative.py

+63
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,66 @@ def test_cumprod_timedelta(self):
227227
ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)])
228228
with pytest.raises(TypeError, match="cumprod not supported for Timedelta"):
229229
ser.cumprod()
230+
231+
@pytest.mark.parametrize(
    "data, skipna, expected_data",
    [
        ([], True, []),
        ([], False, []),
        (["x", "z", "y"], True, ["x", "xz", "xzy"]),
        (["x", "z", "y"], False, ["x", "xz", "xzy"]),
        (["x", pd.NA, "y"], True, ["x", "x", "xy"]),
        (["x", pd.NA, "y"], False, ["x", pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], True, ["", "", ""]),
        ([pd.NA, pd.NA, pd.NA], False, [pd.NA, pd.NA, pd.NA]),
    ],
)
def test_cumsum_pyarrow_strings(
    self, pyarrow_string_dtype, data, skipna, expected_data
):
    # cumsum on pyarrow-backed strings concatenates cumulatively; the cases
    # above cover empty input, no NA, an interior NA and all-NA input for
    # both values of skipna.
    dtype = pyarrow_string_dtype
    result = pd.Series(data, dtype=dtype).cumsum(skipna=skipna)
    expected = pd.Series(expected_data, dtype=dtype)
    tm.assert_series_equal(result, expected)
251+
252+
@pytest.mark.parametrize(
    "data, op, skipna, expected_data",
    [
        ([], "cummin", True, []),
        ([], "cummin", False, []),
        (["y", "z", "x"], "cummin", True, ["y", "y", "x"]),
        (["y", "z", "x"], "cummin", False, ["y", "y", "x"]),
        (["y", pd.NA, "x"], "cummin", True, ["y", "y", "x"]),
        (["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]),
        ([pd.NA, "y", "x"], "cummin", True, ["", "y", "x"]),
        ([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cummin", True, ["", "", ""]),
        ([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]),
        ([], "cummax", True, []),
        ([], "cummax", False, []),
        (["x", "z", "y"], "cummax", True, ["x", "z", "z"]),
        (["x", "z", "y"], "cummax", False, ["x", "z", "z"]),
        (["x", pd.NA, "y"], "cummax", True, ["x", "x", "y"]),
        (["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]),
        ([pd.NA, "x", "y"], "cummax", True, ["", "x", "y"]),
        ([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]),
        ([pd.NA, pd.NA, pd.NA], "cummax", True, ["", "", ""]),
        ([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
    ],
)
def test_cummin_cummax_pyarrow_strings(
    self, pyarrow_string_dtype, data, op, skipna, expected_data
):
    # Running min/max on pyarrow-backed strings. The cases cover empty input,
    # no NA, an interior NA, a leading NA and all-NA input for both values of
    # skipna. Every case supplies a list for ``expected_data``, so no
    # special-casing of a ``None`` expectation is needed.
    ser = pd.Series(data, dtype=pyarrow_string_dtype)
    expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
    result = getattr(ser, op)(skipna=skipna)
    tm.assert_series_equal(result, expected)
287+
288+
def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna):
    # cumprod is not defined for strings and must raise a TypeError that
    # names both the operation and the dtype.
    ser = pd.Series(["x", "y", "z"], dtype=pyarrow_string_dtype)
    expected_msg = f"operation 'cumprod' not supported for dtype '{ser.dtype}'"
    with pytest.raises(TypeError, match=expected_msg):
        ser.cumprod(skipna=skipna)

0 commit comments

Comments
 (0)