Skip to content

Commit 9918c84

Browse files
authored
ENH: Add cumsum to ArrowExtensionArray (#50389)
1 parent 1b653b1 commit 9918c84

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

pandas/core/arrays/arrow/array.py

+39
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,45 @@ def _concat_same_type(
853853
arr = pa.chunked_array(chunks)
854854
return cls(arr)
855855

856+
def _accumulate(
857+
self, name: str, *, skipna: bool = True, **kwargs
858+
) -> ArrowExtensionArray | ExtensionArray:
859+
"""
860+
Return an ExtensionArray performing an accumulation operation.
861+
862+
The underlying data type might change.
863+
864+
Parameters
865+
----------
866+
name : str
867+
Name of the function, supported values are:
868+
- cummin
869+
- cummax
870+
- cumsum
871+
- cumprod
872+
skipna : bool, default True
873+
If True, skip NA values.
874+
**kwargs
875+
Additional keyword arguments passed to the accumulation function.
876+
Currently, there is no supported kwarg.
877+
878+
Returns
879+
-------
880+
array
881+
882+
Raises
883+
------
884+
NotImplementedError : subclass does not define accumulations
885+
"""
886+
pyarrow_name = {
887+
"cumsum": "cumulative_sum_checked",
888+
}.get(name, name)
889+
pyarrow_meth = getattr(pc, pyarrow_name, None)
890+
if pyarrow_meth is None:
891+
return super()._accumulate(name, skipna=skipna, **kwargs)
892+
result = pyarrow_meth(self._data, skip_nulls=skipna, **kwargs)
893+
return type(self)(result)
894+
856895
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
857896
"""
858897
Return a scalar result of performing the reduction operation.

pandas/tests/extension/test_arrow.py

+48
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,54 @@ def test_getitem_scalar(self, data):
343343
super().test_getitem_scalar(data)
344344

345345

346+
class TestBaseAccumulateTests(base.BaseAccumulateTests):
347+
def check_accumulate(self, s, op_name, skipna):
348+
result = getattr(s, op_name)(skipna=skipna).astype("Float64")
349+
expected = getattr(s.astype("Float64"), op_name)(skipna=skipna)
350+
self.assert_series_equal(result, expected, check_dtype=False)
351+
352+
@pytest.mark.parametrize("skipna", [True, False])
353+
def test_accumulate_series_raises(
354+
self, data, all_numeric_accumulations, skipna, request
355+
):
356+
pa_type = data.dtype.pyarrow_dtype
357+
if (
358+
(pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type))
359+
and all_numeric_accumulations == "cumsum"
360+
and not pa_version_under9p0
361+
):
362+
request.node.add_marker(
363+
pytest.mark.xfail(
364+
reason=f"{all_numeric_accumulations} implemented for {pa_type}"
365+
)
366+
)
367+
op_name = all_numeric_accumulations
368+
ser = pd.Series(data)
369+
370+
with pytest.raises(NotImplementedError):
371+
getattr(ser, op_name)(skipna=skipna)
372+
373+
@pytest.mark.parametrize("skipna", [True, False])
374+
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
375+
pa_type = data.dtype.pyarrow_dtype
376+
if all_numeric_accumulations != "cumsum" or pa_version_under9p0:
377+
request.node.add_marker(
378+
pytest.mark.xfail(
379+
reason=f"{all_numeric_accumulations} not implemented",
380+
raises=NotImplementedError,
381+
)
382+
)
383+
elif not (pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type)):
384+
request.node.add_marker(
385+
pytest.mark.xfail(
386+
reason=f"{all_numeric_accumulations} not implemented for {pa_type}"
387+
)
388+
)
389+
op_name = all_numeric_accumulations
390+
ser = pd.Series(data)
391+
self.check_accumulate(ser, op_name, skipna)
392+
393+
346394
class TestBaseNumericReduce(base.BaseNumericReduceTests):
347395
def check_reduce(self, ser, op_name, skipna):
348396
pa_dtype = ser.dtype.pyarrow_dtype

0 commit comments

Comments
 (0)