Skip to content

Commit 579b75a

Browse files
authored
BUG: EWM.__getitem__ raised error with times (#40164)
1 parent b50a2e2 commit 579b75a

File tree

4 files changed

+61
-35
lines changed

4 files changed

+61
-35
lines changed

doc/source/whatsnew/v1.3.0.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,8 @@ Groupby/resample/rolling
520520
- Bug in :meth:`DataFrameGroupBy.sample` where error was raised when ``weights`` was specified and the index was an :class:`Int64Index` (:issue:`39927`)
521521
- Bug in :meth:`DataFrameGroupBy.aggregate` and :meth:`.Resampler.aggregate` would sometimes raise ``SpecificationError`` when passed a dictionary and columns were missing; will now always raise a ``KeyError`` instead (:issue:`40004`)
522522
- Bug in :meth:`DataFrameGroupBy.sample` where column selection was not applied to sample result (:issue:`39928`)
523-
-
523+
- Bug in :class:`core.window.ewm.ExponentialMovingWindow` where calling ``__getitem__`` would incorrectly raise a ``ValueError`` when ``times`` was provided (:issue:`40164`)
524+
- Bug in :class:`core.window.ewm.ExponentialMovingWindow` where calling ``__getitem__`` would not retain the ``com``, ``span``, ``alpha`` or ``halflife`` attributes (:issue:`40164`)
524525

525526
Reshaping
526527
^^^^^^^^^

pandas/_libs/window/aggregations.pyx

+3-4
Original file line numberDiff line numberDiff line change
@@ -1476,7 +1476,7 @@ def roll_weighted_var(const float64_t[:] values, const float64_t[:] weights,
14761476

14771477
def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
14781478
int minp, float64_t com, bint adjust, bint ignore_na,
1479-
const float64_t[:] times, float64_t halflife):
1479+
const float64_t[:] deltas):
14801480
"""
14811481
Compute exponentially-weighted moving average using center-of-mass.
14821482
@@ -1501,7 +1501,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
15011501
Py_ssize_t i, j, s, e, nobs, win_size, N = len(vals), M = len(start)
15021502
const float64_t[:] sub_vals
15031503
ndarray[float64_t] sub_output, output = np.empty(N, dtype=float)
1504-
float64_t alpha, old_wt_factor, new_wt, weighted_avg, old_wt, cur, delta
1504+
float64_t alpha, old_wt_factor, new_wt, weighted_avg, old_wt, cur
15051505
bint is_observation
15061506

15071507
if N == 0:
@@ -1532,8 +1532,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end,
15321532
if weighted_avg == weighted_avg:
15331533

15341534
if is_observation or not ignore_na:
1535-
delta = times[i] - times[i - 1]
1536-
old_wt *= old_wt_factor ** (delta / halflife)
1535+
old_wt *= old_wt_factor ** deltas[i - 1]
15371536
if is_observation:
15381537

15391538
# avoid numerical errors on constant series

pandas/core/window/ewm.py

+33-30
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,13 @@ class ExponentialMovingWindow(BaseWindow):
215215

216216
_attributes = [
217217
"com",
218+
"span",
219+
"halflife",
220+
"alpha",
218221
"min_periods",
219222
"adjust",
220223
"ignore_na",
221224
"axis",
222-
"halflife",
223225
"times",
224226
]
225227

@@ -245,38 +247,48 @@ def __init__(
245247
method="single",
246248
axis=axis,
247249
)
250+
self.com = com
251+
self.span = span
252+
self.halflife = halflife
253+
self.alpha = alpha
248254
self.adjust = adjust
249255
self.ignore_na = ignore_na
250-
if times is not None:
256+
self.times = times
257+
if self.times is not None:
251258
if isinstance(times, str):
252-
times = self._selected_obj[times]
253-
if not is_datetime64_ns_dtype(times):
259+
self.times = self._selected_obj[times]
260+
if not is_datetime64_ns_dtype(self.times):
254261
raise ValueError("times must be datetime64[ns] dtype.")
255-
if len(times) != len(obj):
262+
if len(self.times) != len(obj):
256263
raise ValueError("times must be the same length as the object.")
257264
if not isinstance(halflife, (str, datetime.timedelta)):
258265
raise ValueError(
259266
"halflife must be a string or datetime.timedelta object"
260267
)
261-
if isna(times).any():
268+
if isna(self.times).any():
262269
raise ValueError("Cannot convert NaT values to integer")
263-
self.times = np.asarray(times.view(np.int64))
264-
self.halflife = Timedelta(halflife).value
270+
_times = np.asarray(self.times.view(np.int64), dtype=np.float64)
271+
_halflife = float(Timedelta(self.halflife).value)
272+
self._deltas = np.diff(_times) / _halflife
265273
# Halflife is no longer applicable when calculating COM
266274
# But allow COM to still be calculated if the user passes other decay args
267-
if common.count_not_none(com, span, alpha) > 0:
268-
self.com = get_center_of_mass(com, span, None, alpha)
275+
if common.count_not_none(self.com, self.span, self.alpha) > 0:
276+
self._com = get_center_of_mass(self.com, self.span, None, self.alpha)
269277
else:
270-
self.com = 0.0
278+
self._com = 1.0
271279
else:
272-
if halflife is not None and isinstance(halflife, (str, datetime.timedelta)):
280+
if self.halflife is not None and isinstance(
281+
self.halflife, (str, datetime.timedelta)
282+
):
273283
raise ValueError(
274284
"halflife can only be a timedelta convertible argument if "
275285
"times is not None."
276286
)
277-
self.times = None
278-
self.halflife = None
279-
self.com = get_center_of_mass(com, span, halflife, alpha)
287+
# Without times, points are equally spaced
288+
self._deltas = np.ones(max(len(self.obj) - 1, 0), dtype=np.float64)
289+
self._com = get_center_of_mass(
290+
self.com, self.span, self.halflife, self.alpha
291+
)
280292

281293
def _get_window_indexer(self) -> BaseIndexer:
282294
"""
@@ -334,22 +346,13 @@ def aggregate(self, func, *args, **kwargs):
334346
)
335347
def mean(self, *args, **kwargs):
336348
nv.validate_window_func("mean", args, kwargs)
337-
if self.times is not None:
338-
com = 1.0
339-
times = self.times.astype(np.float64)
340-
halflife = float(self.halflife)
341-
else:
342-
com = self.com
343-
times = np.arange(len(self.obj), dtype=np.float64)
344-
halflife = 1.0
345349
window_func = window_aggregations.ewma
346350
window_func = partial(
347351
window_func,
348-
com=com,
352+
com=self._com,
349353
adjust=self.adjust,
350354
ignore_na=self.ignore_na,
351-
times=times,
352-
halflife=halflife,
355+
deltas=self._deltas,
353356
)
354357
return self._apply(window_func)
355358

@@ -411,7 +414,7 @@ def var(self, bias: bool = False, *args, **kwargs):
411414
window_func = window_aggregations.ewmcov
412415
window_func = partial(
413416
window_func,
414-
com=self.com,
417+
com=self._com,
415418
adjust=self.adjust,
416419
ignore_na=self.ignore_na,
417420
bias=bias,
@@ -480,7 +483,7 @@ def cov_func(x, y):
480483
end,
481484
self.min_periods,
482485
y_array,
483-
self.com,
486+
self._com,
484487
self.adjust,
485488
self.ignore_na,
486489
bias,
@@ -546,7 +549,7 @@ def _cov(X, Y):
546549
end,
547550
self.min_periods,
548551
Y,
549-
self.com,
552+
self._com,
550553
self.adjust,
551554
self.ignore_na,
552555
1,
@@ -613,7 +616,7 @@ def mean(self, engine=None, engine_kwargs=None):
613616
if maybe_use_numba(engine):
614617
groupby_ewma_func = generate_numba_groupby_ewma_func(
615618
engine_kwargs,
616-
self.com,
619+
self._com,
617620
self.adjust,
618621
self.ignore_na,
619622
)

pandas/tests/window/test_ewm.py

+23
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,29 @@ def test_ewm_with_nat_raises(halflife_with_times):
142142
ser.ewm(com=0.1, halflife=halflife_with_times, times=times)
143143

144144

145+
def test_ewm_with_times_getitem(halflife_with_times):
146+
# GH 40164
147+
halflife = halflife_with_times
148+
data = np.arange(10.0)
149+
data[::2] = np.nan
150+
times = date_range("2000", freq="D", periods=10)
151+
df = DataFrame({"A": data, "B": data})
152+
result = df.ewm(halflife=halflife, times=times)["A"].mean()
153+
expected = df.ewm(halflife=1.0)["A"].mean()
154+
tm.assert_series_equal(result, expected)
155+
156+
157+
@pytest.mark.parametrize("arg", ["com", "halflife", "span", "alpha"])
158+
def test_ewm_getitem_attributes_retained(arg, adjust, ignore_na):
159+
# GH 40164
160+
kwargs = {arg: 1, "adjust": adjust, "ignore_na": ignore_na}
161+
ewm = DataFrame({"A": range(1), "B": range(1)}).ewm(**kwargs)
162+
expected = {attr: getattr(ewm, attr) for attr in ewm._attributes}
163+
ewm_slice = ewm["A"]
164+
result = {attr: getattr(ewm_slice, attr) for attr in ewm_slice._attributes}
165+
assert result == expected
166+
167+
145168
def test_ewm_vol_deprecated():
146169
ser = Series(range(1))
147170
with tm.assert_produces_warning(FutureWarning):

0 commit comments

Comments
 (0)