diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index d1a689dc60830..8a9f786fa87b2 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -264,6 +264,7 @@ Other enhancements - Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to lzma.LZMAFile (:issue:`52979`) - Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype objects (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`) - :meth:`DataFrame.to_parquet` and :func:`read_parquet` will now write and read ``attrs`` respectively (:issue:`54346`) +- :meth:`Series.cummax`, :meth:`Series.cummin` and :meth:`Series.cumprod` are now supported for pyarrow dtypes with pyarrow version 13.0 and above (:issue:`52085`) - Added support for the DataFrame Consortium Standard (:issue:`54383`) - Performance improvement in :meth:`.GroupBy.quantile` (:issue:`51722`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 0f46e5a4e7482..3c65e6b4879e2 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1389,6 +1389,9 @@ def _accumulate( NotImplementedError : subclass does not define accumulations """ pyarrow_name = { + "cummax": "cumulative_max", + "cummin": "cumulative_min", + "cumprod": "cumulative_prod_checked", "cumsum": "cumulative_sum_checked", }.get(name, name) pyarrow_meth = getattr(pc, pyarrow_name, None) @@ -1398,12 +1401,20 @@ def _accumulate( data_to_accum = self._pa_array pa_dtype = data_to_accum.type - if pa.types.is_duration(pa_dtype): - data_to_accum = data_to_accum.cast(pa.int64()) + + convert_to_int = ( + pa.types.is_temporal(pa_dtype) and name in ["cummax", "cummin"] + ) or (pa.types.is_duration(pa_dtype) and name == "cumsum") + + if convert_to_int: + if pa_dtype.bit_width == 32: + data_to_accum = data_to_accum.cast(pa.int32()) + else: + data_to_accum = data_to_accum.cast(pa.int64()) result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs) - if pa.types.is_duration(pa_dtype): + if convert_to_int: result = result.cast(pa_dtype) return type(self)(result) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 4c05049ddfcf5..01dedc86c8dca 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -339,10 +339,15 @@ def test_from_sequence_of_strings_pa_array(self, data, request): def check_accumulate(self, ser, op_name, skipna): result = getattr(ser, op_name)(skipna=skipna) - if ser.dtype.kind == "m": + pa_type = ser.dtype.pyarrow_dtype + if pa.types.is_temporal(pa_type): # Just check that we match the integer behavior. - ser = ser.astype("int64[pyarrow]") - result = result.astype("int64[pyarrow]") + if pa_type.bit_width == 32: + int_type = "int32[pyarrow]" + else: + int_type = "int64[pyarrow]" + ser = ser.astype(int_type) + result = result.astype(int_type) result = result.astype("Float64") expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna) @@ -353,14 +358,20 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: # attribute "pyarrow_dtype" pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr] - if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type): - if op_name in ["cumsum", "cumprod"]: + if ( + pa.types.is_string(pa_type) + or pa.types.is_binary(pa_type) + or pa.types.is_decimal(pa_type) + ): + if op_name in ["cumsum", "cumprod", "cummax", "cummin"]: return False - elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type): - if op_name in ["cumsum", "cumprod"]: + elif pa.types.is_boolean(pa_type): + if op_name in ["cumprod", "cummax", "cummin"]: return False - elif pa.types.is_duration(pa_type): - if op_name == "cumprod": + elif pa.types.is_temporal(pa_type): + if op_name == "cumsum" and not pa.types.is_duration(pa_type): + return False + elif op_name == "cumprod": return False return True @@ -376,7 +387,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques data, all_numeric_accumulations, skipna ) - if all_numeric_accumulations != "cumsum" or pa_version_under9p0: + if pa_version_under9p0 or ( + pa_version_under13p0 and all_numeric_accumulations != "cumsum" + ): # xfailing takes a long time to run because pytest # renders the exception messages even when not showing them opt = request.config.option