Skip to content

Commit 0335d39

Browse files
Backport PR #54574 on branch 2.1.x (ENH: add cummax/cummin/cumprod support for arrow dtypes) (#54603)
Backport PR #54574: ENH: add cummax/cummin/cumprod support for arrow dtypes Co-authored-by: Luke Manley <[email protected]>
1 parent 005d876 commit 0335d39

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ Other enhancements
265265
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to ``lzma.LZMAFile`` (:issue:`52979`)
266266
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
267267
- :meth:`DataFrame.to_parquet` and :func:`read_parquet` will now write and read ``attrs`` respectively (:issue:`54346`)
268+
- :meth:`Series.cummax`, :meth:`Series.cummin` and :meth:`Series.cumprod` are now supported for pyarrow dtypes with pyarrow version 13.0 and above (:issue:`52085`)
268269
- Added support for the DataFrame Consortium Standard (:issue:`54383`)
269270
- Performance improvement in :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` (:issue:`51722`)
270271

pandas/core/arrays/arrow/array.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,9 @@ def _accumulate(
13891389
NotImplementedError : subclass does not define accumulations
13901390
"""
13911391
pyarrow_name = {
1392+
"cummax": "cumulative_max",
1393+
"cummin": "cumulative_min",
1394+
"cumprod": "cumulative_prod_checked",
13921395
"cumsum": "cumulative_sum_checked",
13931396
}.get(name, name)
13941397
pyarrow_meth = getattr(pc, pyarrow_name, None)
@@ -1398,12 +1401,20 @@ def _accumulate(
13981401
data_to_accum = self._pa_array
13991402

14001403
pa_dtype = data_to_accum.type
1401-
if pa.types.is_duration(pa_dtype):
1402-
data_to_accum = data_to_accum.cast(pa.int64())
1404+
1405+
convert_to_int = (
1406+
pa.types.is_temporal(pa_dtype) and name in ["cummax", "cummin"]
1407+
) or (pa.types.is_duration(pa_dtype) and name == "cumsum")
1408+
1409+
if convert_to_int:
1410+
if pa_dtype.bit_width == 32:
1411+
data_to_accum = data_to_accum.cast(pa.int32())
1412+
else:
1413+
data_to_accum = data_to_accum.cast(pa.int64())
14031414

14041415
result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs)
14051416

1406-
if pa.types.is_duration(pa_dtype):
1417+
if convert_to_int:
14071418
result = result.cast(pa_dtype)
14081419

14091420
return type(self)(result)

pandas/tests/extension/test_arrow.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,15 @@ class TestBaseAccumulateTests(base.BaseAccumulateTests):
347347
def check_accumulate(self, ser, op_name, skipna):
348348
result = getattr(ser, op_name)(skipna=skipna)
349349

350-
if ser.dtype.kind == "m":
350+
pa_type = ser.dtype.pyarrow_dtype
351+
if pa.types.is_temporal(pa_type):
351352
# Just check that we match the integer behavior.
352-
ser = ser.astype("int64[pyarrow]")
353-
result = result.astype("int64[pyarrow]")
353+
if pa_type.bit_width == 32:
354+
int_type = "int32[pyarrow]"
355+
else:
356+
int_type = "int64[pyarrow]"
357+
ser = ser.astype(int_type)
358+
result = result.astype(int_type)
354359

355360
result = result.astype("Float64")
356361
expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
@@ -361,14 +366,20 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
361366
# attribute "pyarrow_dtype"
362367
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]
363368

364-
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
365-
if op_name in ["cumsum", "cumprod"]:
369+
if (
370+
pa.types.is_string(pa_type)
371+
or pa.types.is_binary(pa_type)
372+
or pa.types.is_decimal(pa_type)
373+
):
374+
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
366375
return False
367-
elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
368-
if op_name in ["cumsum", "cumprod"]:
376+
elif pa.types.is_boolean(pa_type):
377+
if op_name in ["cumprod", "cummax", "cummin"]:
369378
return False
370-
elif pa.types.is_duration(pa_type):
371-
if op_name == "cumprod":
379+
elif pa.types.is_temporal(pa_type):
380+
if op_name == "cumsum" and not pa.types.is_duration(pa_type):
381+
return False
382+
elif op_name == "cumprod":
372383
return False
373384
return True
374385

@@ -384,7 +395,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
384395
data, all_numeric_accumulations, skipna
385396
)
386397

387-
if all_numeric_accumulations != "cumsum" or pa_version_under9p0:
398+
if pa_version_under9p0 or (
399+
pa_version_under13p0 and all_numeric_accumulations != "cumsum"
400+
):
388401
# xfailing takes a long time to run because pytest
389402
# renders the exception messages even when not showing them
390403
opt = request.config.option

0 commit comments

Comments
 (0)