Skip to content

Commit cf7f0af

Browse files
authored
ENH: add cummax/cummin/cumprod support for arrow dtypes (#54574)
* ENH: add cumax, cumin, cumprod support to ArrowExtensionArray * whatsnew * move whatsnew to 2.1.0
1 parent 4efc97c commit cf7f0af

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ Other enhancements
265265
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to ``lzma.LZMAFile`` (:issue:`52979`)
266266
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
267267
- :meth:`DataFrame.to_parquet` and :func:`read_parquet` will now write and read ``attrs`` respectively (:issue:`54346`)
268+
- :meth:`Series.cummax`, :meth:`Series.cummin` and :meth:`Series.cumprod` are now supported for pyarrow dtypes with pyarrow version 13.0 and above (:issue:`52085`)
268269
- Added support for the DataFrame Consortium Standard (:issue:`54383`)
269270
- Performance improvement in :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` (:issue:`51722`)
270271

pandas/core/arrays/arrow/array.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,9 @@ def _accumulate(
13891389
NotImplementedError : subclass does not define accumulations
13901390
"""
13911391
pyarrow_name = {
1392+
"cummax": "cumulative_max",
1393+
"cummin": "cumulative_min",
1394+
"cumprod": "cumulative_prod_checked",
13921395
"cumsum": "cumulative_sum_checked",
13931396
}.get(name, name)
13941397
pyarrow_meth = getattr(pc, pyarrow_name, None)
@@ -1398,12 +1401,20 @@ def _accumulate(
13981401
data_to_accum = self._pa_array
13991402

14001403
pa_dtype = data_to_accum.type
1401-
if pa.types.is_duration(pa_dtype):
1402-
data_to_accum = data_to_accum.cast(pa.int64())
1404+
1405+
convert_to_int = (
1406+
pa.types.is_temporal(pa_dtype) and name in ["cummax", "cummin"]
1407+
) or (pa.types.is_duration(pa_dtype) and name == "cumsum")
1408+
1409+
if convert_to_int:
1410+
if pa_dtype.bit_width == 32:
1411+
data_to_accum = data_to_accum.cast(pa.int32())
1412+
else:
1413+
data_to_accum = data_to_accum.cast(pa.int64())
14031414

14041415
result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs)
14051416

1406-
if pa.types.is_duration(pa_dtype):
1417+
if convert_to_int:
14071418
result = result.cast(pa_dtype)
14081419

14091420
return type(self)(result)

pandas/tests/extension/test_arrow.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,15 @@ def test_from_sequence_of_strings_pa_array(self, data, request):
339339
def check_accumulate(self, ser, op_name, skipna):
340340
result = getattr(ser, op_name)(skipna=skipna)
341341

342-
if ser.dtype.kind == "m":
342+
pa_type = ser.dtype.pyarrow_dtype
343+
if pa.types.is_temporal(pa_type):
343344
# Just check that we match the integer behavior.
344-
ser = ser.astype("int64[pyarrow]")
345-
result = result.astype("int64[pyarrow]")
345+
if pa_type.bit_width == 32:
346+
int_type = "int32[pyarrow]"
347+
else:
348+
int_type = "int64[pyarrow]"
349+
ser = ser.astype(int_type)
350+
result = result.astype(int_type)
346351

347352
result = result.astype("Float64")
348353
expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
@@ -353,14 +358,20 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
353358
# attribute "pyarrow_dtype"
354359
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]
355360

356-
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
357-
if op_name in ["cumsum", "cumprod"]:
361+
if (
362+
pa.types.is_string(pa_type)
363+
or pa.types.is_binary(pa_type)
364+
or pa.types.is_decimal(pa_type)
365+
):
366+
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
358367
return False
359-
elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
360-
if op_name in ["cumsum", "cumprod"]:
368+
elif pa.types.is_boolean(pa_type):
369+
if op_name in ["cumprod", "cummax", "cummin"]:
361370
return False
362-
elif pa.types.is_duration(pa_type):
363-
if op_name == "cumprod":
371+
elif pa.types.is_temporal(pa_type):
372+
if op_name == "cumsum" and not pa.types.is_duration(pa_type):
373+
return False
374+
elif op_name == "cumprod":
364375
return False
365376
return True
366377

@@ -376,7 +387,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
376387
data, all_numeric_accumulations, skipna
377388
)
378389

379-
if all_numeric_accumulations != "cumsum" or pa_version_under9p0:
390+
if pa_version_under9p0 or (
391+
pa_version_under13p0 and all_numeric_accumulations != "cumsum"
392+
):
380393
# xfailing takes a long time to run because pytest
381394
# renders the exception messages even when not showing them
382395
opt = request.config.option

0 commit comments

Comments
 (0)