Skip to content

Commit 6f708b2

Browse files
stevenschaererJulianWgs
authored andcommitted
BUG: various groupby ewm times issues (pandas-dev#40952)
* times in ewm groupby: sort times in according to grouping; add missing support for times in numba implementation; fix bug in cython implementation * add GH issue id to tests * fix typing validation error * PR comments * trying to fix int64 to int32 casting TypeError * PR comments * PR comments * PR comments
1 parent 3d4158c commit 6f708b2

File tree

7 files changed

+170
-18
lines changed

7 files changed

+170
-18
lines changed

doc/source/whatsnew/v1.3.0.rst

+3
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,9 @@ Groupby/resample/rolling
822822
- Bug in :class:`core.window.ewm.ExponentialMovingWindow` when calling ``__getitem__`` would incorrectly raise a ``ValueError`` when providing ``times`` (:issue:`40164`)
823823
- Bug in :class:`core.window.ewm.ExponentialMovingWindow` when calling ``__getitem__`` would not retain ``com``, ``span``, ``alpha`` or ``halflife`` attributes (:issue:`40164`)
824824
- :class:`core.window.ewm.ExponentialMovingWindow` now raises a ``NotImplementedError`` when specifying ``times`` with ``adjust=False`` due to an incorrect calculation (:issue:`40098`)
825+
- Bug in :meth:`core.window.ewm.ExponentialMovingWindowGroupby.mean` where the times argument was ignored when ``engine='numba'`` (:issue:`40951`)
826+
- Bug in :meth:`core.window.ewm.ExponentialMovingWindowGroupby.mean` where the wrong times were used in case of multiple groups (:issue:`40951`)
827+
- Bug in :class:`core.window.ewm.ExponentialMovingWindowGroupby` where the times vector and values became out of sync for non-trivial groups (:issue:`40951`)
825828
- Bug in :meth:`Series.asfreq` and :meth:`DataFrame.asfreq` dropping rows when the index is not sorted (:issue:`39805`)
826829
- Bug in aggregation functions for :class:`DataFrame` not respecting ``numeric_only`` argument when ``level`` keyword was given (:issue:`40660`)
827830
- Bug in :meth:`SeriesGroupBy.aggregate` where using a user-defined function to aggregate a ``Series`` with an object-typed :class:`Index` causes an incorrect :class:`Index` shape (issue:`40014`)

pandas/_libs/window/aggregations.pyx

+6-4
Original file line numberDiff line numberDiff line change
@@ -1485,8 +1485,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
14851485
com : float64
14861486
adjust : bool
14871487
ignore_na : bool
1488-
times : ndarray (float64 type)
1489-
halflife : float64
1488+
deltas : ndarray (float64 type)
14901489
14911490
Returns
14921491
-------
@@ -1495,7 +1494,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
14951494

14961495
cdef:
14971496
Py_ssize_t i, j, s, e, nobs, win_size, N = len(vals), M = len(start)
1498-
const float64_t[:] sub_vals
1497+
const float64_t[:] sub_deltas, sub_vals
14991498
ndarray[float64_t] sub_output, output = np.empty(N, dtype=float)
15001499
float64_t alpha, old_wt_factor, new_wt, weighted_avg, old_wt, cur
15011500
bint is_observation
@@ -1511,6 +1510,9 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
15111510
s = start[j]
15121511
e = end[j]
15131512
sub_vals = vals[s:e]
1513+
# note that len(deltas) = len(vals) - 1 and deltas[i] is to be used in
1514+
# conjunction with vals[i+1]
1515+
sub_deltas = deltas[s:e - 1]
15141516
win_size = len(sub_vals)
15151517
sub_output = np.empty(win_size, dtype=float)
15161518

@@ -1528,7 +1530,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
15281530
if weighted_avg == weighted_avg:
15291531

15301532
if is_observation or not ignore_na:
1531-
old_wt *= old_wt_factor ** deltas[i - 1]
1533+
old_wt *= old_wt_factor ** sub_deltas[i - 1]
15321534
if is_observation:
15331535

15341536
# avoid numerical errors on constant series

pandas/core/window/ewm.py

+45-13
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,38 @@ def get_center_of_mass(
7878
return float(comass)
7979

8080

81+
def _calculate_deltas(
82+
times: str | np.ndarray | FrameOrSeries | None,
83+
halflife: float | TimedeltaConvertibleTypes | None,
84+
) -> np.ndarray:
85+
"""
86+
Return the diff of the times divided by the half-life. These values are used in
87+
the calculation of the ewm mean.
88+
89+
Parameters
90+
----------
91+
times : str, np.ndarray, Series, default None
92+
Times corresponding to the observations. Must be monotonically increasing
93+
and ``datetime64[ns]`` dtype.
94+
halflife : float, str, timedelta, optional
95+
Half-life specifying the decay
96+
97+
Returns
98+
-------
99+
np.ndarray
100+
Diff of the times divided by the half-life
101+
"""
102+
# error: Item "str" of "Union[str, ndarray, FrameOrSeries, None]" has no
103+
# attribute "view"
104+
# error: Item "None" of "Union[str, ndarray, FrameOrSeries, None]" has no
105+
# attribute "view"
106+
_times = np.asarray(
107+
times.view(np.int64), dtype=np.float64 # type: ignore[union-attr]
108+
)
109+
_halflife = float(Timedelta(halflife).value)
110+
return np.diff(_times) / _halflife
111+
112+
81113
class ExponentialMovingWindow(BaseWindow):
82114
r"""
83115
Provide exponential weighted (EW) functions.
@@ -268,15 +300,7 @@ def __init__(
268300
)
269301
if isna(self.times).any():
270302
raise ValueError("Cannot convert NaT values to integer")
271-
# error: Item "str" of "Union[str, ndarray, FrameOrSeries, None]" has no
272-
# attribute "view"
273-
# error: Item "None" of "Union[str, ndarray, FrameOrSeries, None]" has no
274-
# attribute "view"
275-
_times = np.asarray(
276-
self.times.view(np.int64), dtype=np.float64 # type: ignore[union-attr]
277-
)
278-
_halflife = float(Timedelta(self.halflife).value)
279-
self._deltas = np.diff(_times) / _halflife
303+
self._deltas = _calculate_deltas(self.times, self.halflife)
280304
# Halflife is no longer applicable when calculating COM
281305
# But allow COM to still be calculated if the user passes other decay args
282306
if common.count_not_none(self.com, self.span, self.alpha) > 0:
@@ -585,6 +609,17 @@ class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow)
585609

586610
_attributes = ExponentialMovingWindow._attributes + BaseWindowGroupby._attributes
587611

612+
def __init__(self, obj, *args, _grouper=None, **kwargs):
613+
super().__init__(obj, *args, _grouper=_grouper, **kwargs)
614+
615+
if not obj.empty and self.times is not None:
616+
# sort the times and recalculate the deltas according to the groups
617+
groupby_order = np.concatenate(list(self._grouper.indices.values()))
618+
self._deltas = _calculate_deltas(
619+
self.times.take(groupby_order), # type: ignore[union-attr]
620+
self.halflife,
621+
)
622+
588623
def _get_window_indexer(self) -> GroupbyIndexer:
589624
"""
590625
Return an indexer class that will compute the window start and end bounds
@@ -628,10 +663,7 @@ def mean(self, engine=None, engine_kwargs=None):
628663
"""
629664
if maybe_use_numba(engine):
630665
groupby_ewma_func = generate_numba_groupby_ewma_func(
631-
engine_kwargs,
632-
self._com,
633-
self.adjust,
634-
self.ignore_na,
666+
engine_kwargs, self._com, self.adjust, self.ignore_na, self._deltas
635667
)
636668
return self._apply(
637669
groupby_ewma_func,

pandas/core/window/numba_.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def generate_numba_groupby_ewma_func(
8585
com: float,
8686
adjust: bool,
8787
ignore_na: bool,
88+
deltas: np.ndarray,
8889
):
8990
"""
9091
Generate a numba jitted groupby ewma function specified by values
@@ -97,6 +98,7 @@ def generate_numba_groupby_ewma_func(
9798
com : float
9899
adjust : bool
99100
ignore_na : bool
101+
deltas : numpy.ndarray
100102
101103
Returns
102104
-------
@@ -141,7 +143,9 @@ def groupby_ewma(
141143

142144
if is_observation or not ignore_na:
143145

144-
old_wt *= old_wt_factor
146+
# note that len(deltas) = len(vals) - 1 and deltas[i] is to be
147+
# used in conjunction with vals[i+1]
148+
old_wt *= old_wt_factor ** deltas[start + j - 1]
145149
if is_observation:
146150

147151
# avoid numerical errors on constant series

pandas/tests/window/conftest.py

+26
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Series,
1414
bdate_range,
1515
notna,
16+
to_datetime,
1617
)
1718

1819

@@ -302,6 +303,31 @@ def frame():
302303
)
303304

304305

306+
@pytest.fixture
307+
def times_frame():
308+
"""Frame for testing times argument in EWM groupby."""
309+
return DataFrame(
310+
{
311+
"A": ["a", "b", "c", "a", "b", "c", "a", "b", "c", "a"],
312+
"B": [0, 0, 0, 1, 1, 1, 2, 2, 2, 3],
313+
"C": to_datetime(
314+
[
315+
"2020-01-01",
316+
"2020-01-01",
317+
"2020-01-01",
318+
"2020-01-02",
319+
"2020-01-10",
320+
"2020-01-22",
321+
"2020-01-03",
322+
"2020-01-23",
323+
"2020-01-23",
324+
"2020-01-04",
325+
]
326+
),
327+
}
328+
)
329+
330+
305331
@pytest.fixture
306332
def series():
307333
"""Make mocked series as fixture."""

pandas/tests/window/test_groupby.py

+60
Original file line numberDiff line numberDiff line change
@@ -926,3 +926,63 @@ def test_pairwise_methods(self, method, expected_data):
926926

927927
expected = df.groupby("A").apply(lambda x: getattr(x.ewm(com=1.0), method)())
928928
tm.assert_frame_equal(result, expected)
929+
930+
def test_times(self, times_frame):
931+
# GH 40951
932+
halflife = "23 days"
933+
result = times_frame.groupby("A").ewm(halflife=halflife, times="C").mean()
934+
expected = DataFrame(
935+
{
936+
"B": [
937+
0.0,
938+
0.507534,
939+
1.020088,
940+
1.537661,
941+
0.0,
942+
0.567395,
943+
1.221209,
944+
0.0,
945+
0.653141,
946+
1.195003,
947+
]
948+
},
949+
index=MultiIndex.from_tuples(
950+
[
951+
("a", 0),
952+
("a", 3),
953+
("a", 6),
954+
("a", 9),
955+
("b", 1),
956+
("b", 4),
957+
("b", 7),
958+
("c", 2),
959+
("c", 5),
960+
("c", 8),
961+
],
962+
names=["A", None],
963+
),
964+
)
965+
tm.assert_frame_equal(result, expected)
966+
967+
def test_times_vs_apply(self, times_frame):
968+
# GH 40951
969+
halflife = "23 days"
970+
result = times_frame.groupby("A").ewm(halflife=halflife, times="C").mean()
971+
expected = (
972+
times_frame.groupby("A")
973+
.apply(lambda x: x.ewm(halflife=halflife, times="C").mean())
974+
.iloc[[0, 3, 6, 9, 1, 4, 7, 2, 5, 8]]
975+
.reset_index(drop=True)
976+
)
977+
tm.assert_frame_equal(result.reset_index(drop=True), expected)
978+
979+
def test_times_array(self, times_frame):
980+
# GH 40951
981+
halflife = "23 days"
982+
result = times_frame.groupby("A").ewm(halflife=halflife, times="C").mean()
983+
expected = (
984+
times_frame.groupby("A")
985+
.ewm(halflife=halflife, times=times_frame["C"].values)
986+
.mean()
987+
)
988+
tm.assert_frame_equal(result, expected)

pandas/tests/window/test_numba.py

+25
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DataFrame,
99
Series,
1010
option_context,
11+
to_datetime,
1112
)
1213
import pandas._testing as tm
1314
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
@@ -145,6 +146,30 @@ def test_cython_vs_numba(self, nogil, parallel, nopython, ignore_na, adjust):
145146

146147
tm.assert_frame_equal(result, expected)
147148

149+
def test_cython_vs_numba_times(self, nogil, parallel, nopython, ignore_na):
150+
# GH 40951
151+
halflife = "23 days"
152+
times = to_datetime(
153+
[
154+
"2020-01-01",
155+
"2020-01-01",
156+
"2020-01-02",
157+
"2020-01-10",
158+
"2020-02-23",
159+
"2020-01-03",
160+
]
161+
)
162+
df = DataFrame({"A": ["a", "b", "a", "b", "b", "a"], "B": [0, 0, 1, 1, 2, 2]})
163+
gb_ewm = df.groupby("A").ewm(
164+
halflife=halflife, adjust=True, ignore_na=ignore_na, times=times
165+
)
166+
167+
engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
168+
result = gb_ewm.mean(engine="numba", engine_kwargs=engine_kwargs)
169+
expected = gb_ewm.mean(engine="cython")
170+
171+
tm.assert_frame_equal(result, expected)
172+
148173

149174
@td.skip_if_no("numba", "0.46.0")
150175
def test_use_global_config():

0 commit comments

Comments
 (0)