Skip to content

Commit bf1d008

Browse files
authored
ENH: support cumsum with pyarrow durations (pandas-dev#50927)
1 parent 6d6917d commit bf1d008

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

pandas/core/arrays/arrow/array.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -985,7 +985,18 @@ def _accumulate(
985985
pyarrow_meth = getattr(pc, pyarrow_name, None)
986986
if pyarrow_meth is None:
987987
return super()._accumulate(name, skipna=skipna, **kwargs)
988-
result = pyarrow_meth(self._data, skip_nulls=skipna, **kwargs)
988+
989+
data_to_accum = self._data
990+
991+
pa_dtype = data_to_accum.type
992+
if pa.types.is_duration(pa_dtype):
993+
data_to_accum = data_to_accum.cast(pa.int64())
994+
995+
result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs)
996+
997+
if pa.types.is_duration(pa_dtype):
998+
result = result.cast(pa_dtype)
999+
9891000
return type(self)(result)
9901001

9911002
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):

pandas/tests/extension/test_arrow.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -372,16 +372,27 @@ def test_getitem_scalar(self, data):
372372

373373

374374
class TestBaseAccumulateTests(base.BaseAccumulateTests):
375-
def check_accumulate(self, s, op_name, skipna):
376-
result = getattr(s, op_name)(skipna=skipna).astype("Float64")
377-
expected = getattr(s.astype("Float64"), op_name)(skipna=skipna)
375+
def check_accumulate(self, ser, op_name, skipna):
376+
result = getattr(ser, op_name)(skipna=skipna)
377+
378+
if ser.dtype.kind == "m":
379+
# Just check that we match the integer behavior.
380+
ser = ser.astype("int64[pyarrow]")
381+
result = result.astype("int64[pyarrow]")
382+
383+
result = result.astype("Float64")
384+
expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
378385
self.assert_series_equal(result, expected, check_dtype=False)
379386

380387
@pytest.mark.parametrize("skipna", [True, False])
381388
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
382389
pa_type = data.dtype.pyarrow_dtype
383390
if (
384-
(pa.types.is_integer(pa_type) or pa.types.is_floating(pa_type))
391+
(
392+
pa.types.is_integer(pa_type)
393+
or pa.types.is_floating(pa_type)
394+
or pa.types.is_duration(pa_type)
395+
)
385396
and all_numeric_accumulations == "cumsum"
386397
and not pa_version_under9p0
387398
):
@@ -423,9 +434,7 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
423434
raises=NotImplementedError,
424435
)
425436
)
426-
elif all_numeric_accumulations == "cumsum" and (
427-
pa.types.is_duration(pa_type) or pa.types.is_boolean(pa_type)
428-
):
437+
elif all_numeric_accumulations == "cumsum" and (pa.types.is_boolean(pa_type)):
429438
request.node.add_marker(
430439
pytest.mark.xfail(
431440
reason=f"{all_numeric_accumulations} not implemented for {pa_type}",

0 commit comments

Comments
 (0)