Skip to content

Commit a38a24e

Browse files
authored
Use ea interface to calculate accumulator functions for datetimelike (#50297)
* Implement ea accumulate for datetimelike * Fix dtype cast * Remove from nanops * Add period tests * Move comment * Move comment * Address review * Dont retain freq * Add tests * Address review
1 parent b173e7b commit a38a24e

File tree

7 files changed

+169
-70
lines changed

7 files changed

+169
-70
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
datetimelke_accumulations.py is for accumulations of datetimelike extension arrays
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from typing import Callable
8+
9+
import numpy as np
10+
11+
from pandas._libs import iNaT
12+
13+
from pandas.core.dtypes.missing import isna
14+
15+
16+
def _cum_func(
17+
func: Callable,
18+
values: np.ndarray,
19+
*,
20+
skipna: bool = True,
21+
):
22+
"""
23+
Accumulations for 1D datetimelike arrays.
24+
25+
Parameters
26+
----------
27+
func : np.cumsum, np.maximum.accumulate, np.minimum.accumulate
28+
values : np.ndarray
29+
Numpy array with the values (can be of any dtype that support the
30+
operation). Values is changed is modified inplace.
31+
skipna : bool, default True
32+
Whether to skip NA.
33+
"""
34+
try:
35+
fill_value = {
36+
np.maximum.accumulate: np.iinfo(np.int64).min,
37+
np.cumsum: 0,
38+
np.minimum.accumulate: np.iinfo(np.int64).max,
39+
}[func]
40+
except KeyError:
41+
raise ValueError(f"No accumulation for {func} implemented on BaseMaskedArray")
42+
43+
mask = isna(values)
44+
y = values.view("i8")
45+
y[mask] = fill_value
46+
47+
if not skipna:
48+
mask = np.maximum.accumulate(mask)
49+
50+
result = func(y)
51+
result[mask] = iNaT
52+
53+
if values.dtype.kind in ["m", "M"]:
54+
return result.view(values.dtype.base)
55+
return result
56+
57+
58+
def cumsum(values: np.ndarray, *, skipna: bool = True) -> np.ndarray:
59+
return _cum_func(np.cumsum, values, skipna=skipna)
60+
61+
62+
def cummin(values: np.ndarray, *, skipna: bool = True):
63+
return _cum_func(np.minimum.accumulate, values, skipna=skipna)
64+
65+
66+
def cummax(values: np.ndarray, *, skipna: bool = True):
67+
return _cum_func(np.maximum.accumulate, values, skipna=skipna)

pandas/core/arrays/datetimelike.py

+8-17
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
isin,
122122
unique1d,
123123
)
124+
from pandas.core.array_algos import datetimelike_accumulations
124125
from pandas.core.arraylike import OpsMixin
125126
from pandas.core.arrays._mixins import (
126127
NDArrayBackedExtensionArray,
@@ -1292,25 +1293,15 @@ def _addsub_object_array(self, other: npt.NDArray[np.object_], op):
12921293
return res_values
12931294

12941295
def _accumulate(self, name: str, *, skipna: bool = True, **kwargs):
1296+
if name not in {"cummin", "cummax"}:
1297+
raise TypeError(f"Accumulation {name} not supported for {type(self)}")
12951298

1296-
if is_period_dtype(self.dtype):
1297-
data = self
1298-
else:
1299-
# Incompatible types in assignment (expression has type
1300-
# "ndarray[Any, Any]", variable has type "DatetimeLikeArrayMixin"
1301-
data = self._ndarray.copy() # type: ignore[assignment]
1302-
1303-
if name in {"cummin", "cummax"}:
1304-
func = np.minimum.accumulate if name == "cummin" else np.maximum.accumulate
1305-
result = cast(np.ndarray, nanops.na_accum_func(data, func, skipna=skipna))
1306-
1307-
# error: Unexpected keyword argument "freq" for
1308-
# "_simple_new" of "NDArrayBacked" [call-arg]
1309-
return type(self)._simple_new(
1310-
result, freq=self.freq, dtype=self.dtype # type: ignore[call-arg]
1311-
)
1299+
op = getattr(datetimelike_accumulations, name)
1300+
result = op(self.copy(), skipna=skipna, **kwargs)
13121301

1313-
raise TypeError(f"Accumulation {name} not supported for {type(self)}")
1302+
return type(self)._simple_new(
1303+
result, freq=None, dtype=self.dtype # type: ignore[call-arg]
1304+
)
13141305

13151306
@unpack_zerodim_and_defer("__add__")
13161307
def __add__(self, other):

pandas/core/arrays/timedeltas.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from pandas.core.dtypes.missing import isna
6464

6565
from pandas.core import nanops
66+
from pandas.core.array_algos import datetimelike_accumulations
6667
from pandas.core.arrays import datetimelike as dtl
6768
from pandas.core.arrays._ranges import generate_regular_range
6869
import pandas.core.common as com
@@ -418,12 +419,9 @@ def std(
418419
# Accumulations
419420

420421
def _accumulate(self, name: str, *, skipna: bool = True, **kwargs):
421-
422-
data = self._ndarray.copy()
423-
424422
if name == "cumsum":
425-
func = np.cumsum
426-
result = cast(np.ndarray, nanops.na_accum_func(data, func, skipna=skipna))
423+
op = getattr(datetimelike_accumulations, name)
424+
result = op(self._ndarray.copy(), skipna=skipna, **kwargs)
427425

428426
return type(self)._simple_new(result, freq=None, dtype=self.dtype)
429427
elif name == "cumprod":

pandas/core/nanops.py

+4-46
Original file line numberDiff line numberDiff line change
@@ -1715,53 +1715,11 @@ def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike:
17151715
np.minimum.accumulate: (np.inf, np.nan),
17161716
}[accum_func]
17171717

1718-
# We will be applying this function to block values
1719-
if values.dtype.kind in ["m", "M"]:
1720-
# GH#30460, GH#29058
1721-
# numpy 1.18 started sorting NaTs at the end instead of beginning,
1722-
# so we need to work around to maintain backwards-consistency.
1723-
orig_dtype = values.dtype
1724-
1725-
# We need to define mask before masking NaTs
1726-
mask = isna(values)
1727-
1728-
y = values.view("i8")
1729-
# Note: the accum_func comparison fails as an "is" comparison
1730-
changed = accum_func == np.minimum.accumulate
1731-
1732-
try:
1733-
if changed:
1734-
y[mask] = lib.i8max
1735-
1736-
result = accum_func(y, axis=0)
1737-
finally:
1738-
if changed:
1739-
# restore NaT elements
1740-
y[mask] = iNaT
1718+
# This should go through ea interface
1719+
assert values.dtype.kind not in ["m", "M"]
17411720

1742-
if skipna:
1743-
result[mask] = iNaT
1744-
elif accum_func == np.minimum.accumulate:
1745-
# Restore NaTs that we masked previously
1746-
nz = (~np.asarray(mask)).nonzero()[0]
1747-
if len(nz):
1748-
# everything up to the first non-na entry stays NaT
1749-
result[: nz[0]] = iNaT
1750-
1751-
if isinstance(values.dtype, np.dtype):
1752-
result = result.view(orig_dtype)
1753-
else:
1754-
# DatetimeArray/TimedeltaArray
1755-
# TODO: have this case go through a DTA method?
1756-
# For DatetimeTZDtype, view result as M8[ns]
1757-
npdtype = orig_dtype if isinstance(orig_dtype, np.dtype) else "M8[ns]"
1758-
# Item "type" of "Union[Type[ExtensionArray], Type[ndarray[Any, Any]]]"
1759-
# has no attribute "_simple_new"
1760-
result = type(values)._simple_new( # type: ignore[union-attr]
1761-
result.view(npdtype), dtype=orig_dtype
1762-
)
1763-
1764-
elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)):
1721+
# We will be applying this function to block values
1722+
if skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)):
17651723
vals = values.copy()
17661724
mask = isna(vals)
17671725
vals[mask] = mask_a
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
3+
import pandas._testing as tm
4+
from pandas.core.arrays import DatetimeArray
5+
6+
7+
class TestAccumulator:
8+
def test_accumulators_freq(self):
9+
# GH#50297
10+
arr = DatetimeArray._from_sequence_not_strict(
11+
[
12+
"2000-01-01",
13+
"2000-01-02",
14+
"2000-01-03",
15+
],
16+
freq="D",
17+
)
18+
result = arr._accumulate("cummin")
19+
expected = DatetimeArray._from_sequence_not_strict(
20+
["2000-01-01"] * 3, freq=None
21+
)
22+
tm.assert_datetime_array_equal(result, expected)
23+
24+
result = arr._accumulate("cummax")
25+
expected = DatetimeArray._from_sequence_not_strict(
26+
[
27+
"2000-01-01",
28+
"2000-01-02",
29+
"2000-01-03",
30+
],
31+
freq=None,
32+
)
33+
tm.assert_datetime_array_equal(result, expected)
34+
35+
@pytest.mark.parametrize("func", ["cumsum", "cumprod"])
36+
def test_accumulators_disallowed(self, func):
37+
# GH#50297
38+
arr = DatetimeArray._from_sequence_not_strict(
39+
[
40+
"2000-01-01",
41+
"2000-01-02",
42+
],
43+
freq="D",
44+
)
45+
with pytest.raises(TypeError, match=f"Accumulation {func}"):
46+
arr._accumulate(func)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
import pandas._testing as tm
4+
from pandas.core.arrays import TimedeltaArray
5+
6+
7+
class TestAccumulator:
8+
def test_accumulators_disallowed(self):
9+
# GH#50297
10+
arr = TimedeltaArray._from_sequence_not_strict(["1D", "2D"])
11+
with pytest.raises(TypeError, match="cumprod not supported"):
12+
arr._accumulate("cumprod")
13+
14+
def test_cumsum(self):
15+
# GH#50297
16+
arr = TimedeltaArray._from_sequence_not_strict(["1D", "2D"])
17+
result = arr._accumulate("cumsum")
18+
expected = TimedeltaArray._from_sequence_not_strict(["1D", "3D"])
19+
tm.assert_timedelta_array_equal(result, expected)

pandas/tests/series/test_cumulative.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ def test_cummin_cummax(self, datetime_series, method):
7070
[
7171
"cummax",
7272
False,
73-
["NaT", "2 days", "2 days", "2 days", "2 days", "3 days"],
73+
["NaT", "NaT", "NaT", "NaT", "NaT", "NaT"],
7474
],
7575
[
7676
"cummin",
7777
False,
78-
["NaT", "2 days", "2 days", "1 days", "1 days", "1 days"],
78+
["NaT", "NaT", "NaT", "NaT", "NaT", "NaT"],
7979
],
8080
],
8181
)
@@ -91,6 +91,26 @@ def test_cummin_cummax_datetimelike(self, ts, method, skipna, exp_tdi):
9191
result = getattr(ser, method)(skipna=skipna)
9292
tm.assert_series_equal(expected, result)
9393

94+
@pytest.mark.parametrize(
95+
"func, exp",
96+
[
97+
("cummin", pd.Period("2012-1-1", freq="D")),
98+
("cummax", pd.Period("2012-1-2", freq="D")),
99+
],
100+
)
101+
def test_cummin_cummax_period(self, func, exp):
102+
# GH#28385
103+
ser = pd.Series(
104+
[pd.Period("2012-1-1", freq="D"), pd.NaT, pd.Period("2012-1-2", freq="D")]
105+
)
106+
result = getattr(ser, func)(skipna=False)
107+
expected = pd.Series([pd.Period("2012-1-1", freq="D"), pd.NaT, pd.NaT])
108+
tm.assert_series_equal(result, expected)
109+
110+
result = getattr(ser, func)(skipna=True)
111+
expected = pd.Series([pd.Period("2012-1-1", freq="D"), pd.NaT, exp])
112+
tm.assert_series_equal(result, expected)
113+
94114
@pytest.mark.parametrize(
95115
"arg",
96116
[

0 commit comments

Comments
 (0)