diff --git a/doc/source/whatsnew/v1.1.1.rst b/doc/source/whatsnew/v1.1.1.rst index 415f9e508feb8..cdc244ca193b4 100644 --- a/doc/source/whatsnew/v1.1.1.rst +++ b/doc/source/whatsnew/v1.1.1.rst @@ -22,6 +22,7 @@ Fixed regressions - Fixed regression in :meth:`DataFrame.shift` with ``axis=1`` and heterogeneous dtypes (:issue:`35488`) - Fixed regression in ``.groupby(..).rolling(..)`` where a segfault would occur with ``center=True`` and an odd number of values (:issue:`35552`) - Fixed regression in :meth:`DataFrame.apply` where functions that altered the input in-place only operated on a single row (:issue:`35462`) +- Fixed regression in ``.groupby(..).rolling(..)`` where a custom ``BaseIndexer`` would be ignored (:issue:`35557`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/window/indexers.py b/pandas/core/window/indexers.py index bc36bdca982e8..7cbe34cdebf9f 100644 --- a/pandas/core/window/indexers.py +++ b/pandas/core/window/indexers.py @@ -1,6 +1,6 @@ """Indexer objects for computing start/end window bounds for rolling operations""" from datetime import timedelta -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, Optional, Tuple, Type import numpy as np @@ -265,7 +265,8 @@ def __init__( index_array: Optional[np.ndarray], window_size: int, groupby_indicies: Dict, - rolling_indexer: Union[Type[FixedWindowIndexer], Type[VariableWindowIndexer]], + rolling_indexer: Type[BaseIndexer], + indexer_kwargs: Optional[Dict], **kwargs, ): """ @@ -276,7 +277,10 @@ def __init__( """ self.groupby_indicies = groupby_indicies self.rolling_indexer = rolling_indexer - super().__init__(index_array, window_size, **kwargs) + self.indexer_kwargs = indexer_kwargs or {} + super().__init__( + index_array, self.indexer_kwargs.pop("window_size", window_size), **kwargs + ) @Appender(get_window_bounds_doc) def get_window_bounds( @@ -298,7 +302,9 @@ def get_window_bounds( else: index_array = self.index_array indexer = self.rolling_indexer( - index_array=index_array, window_size=self.window_size, + index_array=index_array, + window_size=self.window_size, + **self.indexer_kwargs, ) start, end = indexer.get_window_bounds( len(indicies), min_periods, center, closed diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index 7347d5686aabc..0306d4de2fc73 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -145,7 +145,7 @@ class _Window(PandasObject, ShallowMixin, SelectionMixin): def __init__( self, - obj, + obj: FrameOrSeries, window=None, min_periods: Optional[int] = None, center: bool = False, @@ -2271,10 +2271,16 @@ def _get_window_indexer(self, window: int) -> GroupbyRollingIndexer: ------- GroupbyRollingIndexer """ - rolling_indexer: Union[Type[FixedWindowIndexer], Type[VariableWindowIndexer]] - if self.is_freq_type: + rolling_indexer: Type[BaseIndexer] + indexer_kwargs: Optional[Dict] = None + index_array = self.obj.index.asi8 + if isinstance(self.window, BaseIndexer): + rolling_indexer = type(self.window) + indexer_kwargs = self.window.__dict__ + # We'll be using the index of each group later + indexer_kwargs.pop("index_array", None) + elif self.is_freq_type: rolling_indexer = VariableWindowIndexer - index_array = self.obj.index.asi8 else: rolling_indexer = FixedWindowIndexer index_array = None @@ -2283,6 +2289,7 @@ def _get_window_indexer(self, window: int) -> GroupbyRollingIndexer: window_size=window, groupby_indicies=self._groupby.indices, rolling_indexer=rolling_indexer, + indexer_kwargs=indexer_kwargs, ) return window_indexer diff --git a/pandas/tests/window/test_grouper.py b/pandas/tests/window/test_grouper.py index e1dcac06c39cc..a9590c7e1233a 100644 --- a/pandas/tests/window/test_grouper.py +++ b/pandas/tests/window/test_grouper.py @@ -305,6 +305,29 @@ def test_groupby_subselect_rolling(self): ) tm.assert_series_equal(result, expected) + def test_groupby_rolling_custom_indexer(self): + # GH 35557 + class SimpleIndexer(pd.api.indexers.BaseIndexer): + def get_window_bounds( + self, num_values=0, min_periods=None, center=None, closed=None + ): + min_periods = self.window_size if min_periods is None else 0 + end = np.arange(num_values, dtype=np.int64) + 1 + start = end.copy() - self.window_size + start[start < 0] = min_periods + return start, end + + df = pd.DataFrame( + {"a": [1.0, 2.0, 3.0, 4.0, 5.0] * 3}, index=[0] * 5 + [1] * 5 + [2] * 5 + ) + result = ( + df.groupby(df.index) + .rolling(SimpleIndexer(window_size=3), min_periods=1) + .sum() + ) + expected = df.groupby(df.index).rolling(window=3, min_periods=1).sum() + tm.assert_frame_equal(result, expected) + def test_groupby_rolling_subset_with_closed(self): # GH 35549 df = pd.DataFrame(