diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 5c61f259a4202..9bb9f0c7a467a 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -520,7 +520,8 @@ Groupby/resample/rolling - Bug in :meth:`DataFrameGroupBy.sample` where error was raised when ``weights`` was specified and the index was an :class:`Int64Index` (:issue:`39927`) - Bug in :meth:`DataFrameGroupBy.aggregate` and :meth:`.Resampler.aggregate` would sometimes raise ``SpecificationError`` when passed a dictionary and columns were missing; will now always raise a ``KeyError`` instead (:issue:`40004`) - Bug in :meth:`DataFrameGroupBy.sample` where column selection was not applied to sample result (:issue:`39928`) -- +- Bug in :class:`core.window.ewm.ExponentialMovingWindow` when calling ``__getitem__`` would incorrectly raise a ``ValueError`` when providing ``times`` (:issue:`40164`) +- Bug in :class:`core.window.ewm.ExponentialMovingWindow` when calling ``__getitem__`` would not retain ``com``, ``span``, ``alpha`` or ``halflife`` attributes (:issue:`40164`) Reshaping ^^^^^^^^^ diff --git a/pandas/_libs/window/aggregations.pyx b/pandas/_libs/window/aggregations.pyx index fcd2acc3e025a..efacfad40ef82 100644 --- a/pandas/_libs/window/aggregations.pyx +++ b/pandas/_libs/window/aggregations.pyx @@ -1476,7 +1476,7 @@ def roll_weighted_var(const float64_t[:] values, const float64_t[:] weights, def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end, int minp, float64_t com, bint adjust, bint ignore_na, - const float64_t[:] times, float64_t halflife): + const float64_t[:] deltas): """ Compute exponentially-weighted moving average using center-of-mass. @@ -1501,7 +1501,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end, Py_ssize_t i, j, s, e, nobs, win_size, N = len(vals), M = len(start) const float64_t[:] sub_vals ndarray[float64_t] sub_output, output = np.empty(N, dtype=float) - float64_t alpha, old_wt_factor, new_wt, weighted_avg, old_wt, cur, delta + float64_t alpha, old_wt_factor, new_wt, weighted_avg, old_wt, cur bint is_observation if N == 0: @@ -1532,8 +1532,7 @@ def ewma(const float64_t[:] vals, const int64_t[:] start, const int64_t[:] end, if weighted_avg == weighted_avg: if is_observation or not ignore_na: - delta = times[i] - times[i - 1] - old_wt *= old_wt_factor ** (delta / halflife) + old_wt *= old_wt_factor ** deltas[i - 1] if is_observation: # avoid numerical errors on constant series diff --git a/pandas/core/window/ewm.py b/pandas/core/window/ewm.py index 696c4af27e3f2..5a71db82f26e4 100644 --- a/pandas/core/window/ewm.py +++ b/pandas/core/window/ewm.py @@ -215,11 +215,13 @@ class ExponentialMovingWindow(BaseWindow): _attributes = [ "com", + "span", + "halflife", + "alpha", "min_periods", "adjust", "ignore_na", "axis", - "halflife", "times", ] @@ -245,38 +247,48 @@ def __init__( method="single", axis=axis, ) + self.com = com + self.span = span + self.halflife = halflife + self.alpha = alpha self.adjust = adjust self.ignore_na = ignore_na - if times is not None: + self.times = times + if self.times is not None: if isinstance(times, str): - times = self._selected_obj[times] - if not is_datetime64_ns_dtype(times): + self.times = self._selected_obj[times] + if not is_datetime64_ns_dtype(self.times): raise ValueError("times must be datetime64[ns] dtype.") - if len(times) != len(obj): + if len(self.times) != len(obj): raise ValueError("times must be the same length as the object.") if not isinstance(halflife, (str, datetime.timedelta)): raise ValueError( "halflife must be a string or datetime.timedelta object" ) - if isna(times).any(): + if isna(self.times).any(): raise ValueError("Cannot convert NaT values to integer") - self.times = np.asarray(times.view(np.int64)) - self.halflife = Timedelta(halflife).value + _times = np.asarray(self.times.view(np.int64), dtype=np.float64) + _halflife = float(Timedelta(self.halflife).value) + self._deltas = np.diff(_times) / _halflife # Halflife is no longer applicable when calculating COM # But allow COM to still be calculated if the user passes other decay args - if common.count_not_none(com, span, alpha) > 0: - self.com = get_center_of_mass(com, span, None, alpha) + if common.count_not_none(self.com, self.span, self.alpha) > 0: + self._com = get_center_of_mass(self.com, self.span, None, self.alpha) else: - self.com = 0.0 + self._com = 1.0 else: - if halflife is not None and isinstance(halflife, (str, datetime.timedelta)): + if self.halflife is not None and isinstance( + self.halflife, (str, datetime.timedelta) + ): raise ValueError( "halflife can only be a timedelta convertible argument if " "times is not None." ) - self.times = None - self.halflife = None - self.com = get_center_of_mass(com, span, halflife, alpha) + # Without times, points are equally spaced + self._deltas = np.ones(max(len(self.obj) - 1, 0), dtype=np.float64) + self._com = get_center_of_mass( + self.com, self.span, self.halflife, self.alpha + ) def _get_window_indexer(self) -> BaseIndexer: """ @@ -334,22 +346,13 @@ def aggregate(self, func, *args, **kwargs): ) def mean(self, *args, **kwargs): nv.validate_window_func("mean", args, kwargs) - if self.times is not None: - com = 1.0 - times = self.times.astype(np.float64) - halflife = float(self.halflife) - else: - com = self.com - times = np.arange(len(self.obj), dtype=np.float64) - halflife = 1.0 window_func = window_aggregations.ewma window_func = partial( window_func, - com=com, + com=self._com, adjust=self.adjust, ignore_na=self.ignore_na, - times=times, - halflife=halflife, + deltas=self._deltas, ) return self._apply(window_func) @@ -411,7 +414,7 @@ def var(self, bias: bool = False, *args, **kwargs): window_func = window_aggregations.ewmcov window_func = partial( window_func, - com=self.com, + com=self._com, adjust=self.adjust, ignore_na=self.ignore_na, bias=bias, @@ -480,7 +483,7 @@ def cov_func(x, y): end, self.min_periods, y_array, - self.com, + self._com, self.adjust, self.ignore_na, bias, @@ -546,7 +549,7 @@ def _cov(X, Y): end, self.min_periods, Y, - self.com, + self._com, self.adjust, self.ignore_na, 1, @@ -613,7 +616,7 @@ def mean(self, engine=None, engine_kwargs=None): if maybe_use_numba(engine): groupby_ewma_func = generate_numba_groupby_ewma_func( engine_kwargs, - self.com, + self._com, self.adjust, self.ignore_na, ) diff --git a/pandas/tests/window/test_ewm.py b/pandas/tests/window/test_ewm.py index fbd7a36a75bf0..3e823844c7f56 100644 --- a/pandas/tests/window/test_ewm.py +++ b/pandas/tests/window/test_ewm.py @@ -142,6 +142,29 @@ def test_ewm_with_nat_raises(halflife_with_times): ser.ewm(com=0.1, halflife=halflife_with_times, times=times) +def test_ewm_with_times_getitem(halflife_with_times): + # GH 40164 + halflife = halflife_with_times + data = np.arange(10.0) + data[::2] = np.nan + times = date_range("2000", freq="D", periods=10) + df = DataFrame({"A": data, "B": data}) + result = df.ewm(halflife=halflife, times=times)["A"].mean() + expected = df.ewm(halflife=1.0)["A"].mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("arg", ["com", "halflife", "span", "alpha"]) +def test_ewm_getitem_attributes_retained(arg, adjust, ignore_na): + # GH 40164 + kwargs = {arg: 1, "adjust": adjust, "ignore_na": ignore_na} + ewm = DataFrame({"A": range(1), "B": range(1)}).ewm(**kwargs) + expected = {attr: getattr(ewm, attr) for attr in ewm._attributes} + ewm_slice = ewm["A"] + result = {attr: getattr(ewm, attr) for attr in ewm_slice._attributes} + assert result == expected + + def test_ewm_vol_deprecated(): ser = Series(range(1)) with tm.assert_produces_warning(FutureWarning):