From 0d30a5e7a6007508b967fa0e5f9987c2d0236cba Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 21 Dec 2022 16:57:09 -0800 Subject: [PATCH 1/3] ENH: Add cumsum to ArrowExtensionArray --- pandas/core/arrays/arrow/array.py | 39 +++++++++++++++++++++++ pandas/tests/extension/test_arrow.py | 46 ++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 1bbec97756e79..8e874cf726cc9 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -853,6 +853,45 @@ def _concat_same_type( arr = pa.chunked_array(chunks) return cls(arr) + def _accumulate( + self, name: str, *, skipna: bool = True, **kwargs + ) -> ArrowExtensionArrayT: + """ + Return an ExtensionArray performing an accumulation operation. + + The underlying data type might change. + + Parameters + ---------- + name : str + Name of the function, supported values are: + - cummin + - cummax + - cumsum + - cumprod + skipna : bool, default True + If True, skip NA values. + **kwargs + Additional keyword arguments passed to the accumulation function. + Currently, there is no supported kwarg. + + Returns + ------- + array + + Raises + ------ + NotImplementedError : subclass does not define accumulations + """ + pyarrow_name = { + "cumsum": "cumulative_sum_checked", + }.get(name, name) + pyarrow_meth = getattr(pc, pyarrow_name, None) + if pyarrow_meth is None: + return super()._accumulate(name, skipna=skipna, **kwargs) + result = pyarrow_meth(self._data, skip_nulls=skipna, **kwargs) + return type(self)(result) + def _reduce(self, name: str, *, skipna: bool = True, **kwargs): """ Return a scalar result of performing the reduction operation. diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index f93cf3d6bc138..bb1cda67d98d7 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -343,6 +343,52 @@ def test_getitem_scalar(self, data): super().test_getitem_scalar(data) +class TestBaseAccumulateTests(base.BaseAccumulateTests): + def check_accumulate(self, s, op_name, skipna): + result = getattr(s, op_name)(skipna=skipna).astype("Float64") + expected = getattr(s.astype("Float64"), op_name)(skipna=skipna) + self.assert_series_equal(result, expected, check_dtype=False) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_accumulate_series_raises( + self, data, all_numeric_accumulations, skipna, request + ): + pa_type = data.dtype.pyarrow_dtype + if ( + pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type) + ) and all_numeric_accumulations == "cumsum": + request.node.add_marker( + pytest.mark.xfail( + reason=f"{all_numeric_accumulations} implemented for {pa_type}" + ) + ) + op_name = all_numeric_accumulations + ser = pd.Series(data) + + with pytest.raises(NotImplementedError): + getattr(ser, op_name)(skipna=skipna) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request): + pa_type = data.dtype.pyarrow_dtype + if all_numeric_accumulations != "cumsum": + request.node.add_marker( + pytest.mark.xfail( + reason=f"{all_numeric_accumulations} not implemented", + raises=NotImplementedError, + ) + ) + elif not (pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type)): + request.node.add_marker( + pytest.mark.xfail( + reason=f"{all_numeric_accumulations} not implemented for {pa_type}" + ) + ) + op_name = all_numeric_accumulations + ser = pd.Series(data) + self.check_accumulate(ser, op_name, skipna) + + class TestBaseNumericReduce(base.BaseNumericReduceTests): def check_reduce(self, ser, op_name, skipna): pa_dtype = ser.dtype.pyarrow_dtype From d1dae301b3a87e121c9bb2b5f727ac40ebbcc97f Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 22 Dec 2022 15:22:02 -0800 Subject: [PATCH 2/3] Address pa under 9 --- pandas/core/arrays/arrow/array.py | 4 ++-- pandas/tests/extension/test_arrow.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 8e874cf726cc9..781a3f9905cff 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -853,9 +853,9 @@ def _concat_same_type( arr = pa.chunked_array(chunks) return cls(arr) - def _accumulate( + def _accumulate( # type: ignore[override] self, name: str, *, skipna: bool = True, **kwargs - ) -> ArrowExtensionArrayT: + ) -> ArrowExtensionArray: """ Return an ExtensionArray performing an accumulation operation. diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index bb1cda67d98d7..9b42b86efd0d0 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -355,8 +355,10 @@ def test_accumulate_series_raises( ): pa_type = data.dtype.pyarrow_dtype if ( - pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type) - ) and all_numeric_accumulations == "cumsum": + (pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type)) + and all_numeric_accumulations == "cumsum" + and not pa_version_under9p0 + ): request.node.add_marker( pytest.mark.xfail( reason=f"{all_numeric_accumulations} implemented for {pa_type}" @@ -371,7 +373,7 @@ def test_accumulate_series_raises( @pytest.mark.parametrize("skipna", [True, False]) def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request): pa_type = data.dtype.pyarrow_dtype - if all_numeric_accumulations != "cumsum": + if all_numeric_accumulations != "cumsum" or pa_version_under9p0: request.node.add_marker( pytest.mark.xfail( reason=f"{all_numeric_accumulations} not implemented", From 386c071e4786e2266428ac44086d3af71aea9ec3 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 23 Dec 2022 15:25:57 -0800 Subject: [PATCH 3/3] Fix typing --- pandas/core/arrays/arrow/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 781a3f9905cff..6250c298f291f 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -853,9 +853,9 @@ def _concat_same_type( arr = pa.chunked_array(chunks) return cls(arr) - def _accumulate( # type: ignore[override] + def _accumulate( self, name: str, *, skipna: bool = True, **kwargs - ) -> ArrowExtensionArray: + ) -> ArrowExtensionArray | ExtensionArray: """ Return an ExtensionArray performing an accumulation operation.