Skip to content

Commit b3b8128

Browse files
mroeschkemeeseeksmachine
authored andcommitted
Backport PR pandas-dev#35647: BUG: Support custom BaseIndexers in groupby.rolling
1 parent ac40043 commit b3b8128

File tree

4 files changed

+45
-8
lines changed

4 files changed

+45
-8
lines changed

doc/source/whatsnew/v1.1.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Fixed regressions
2222
- Fixed regression in :meth:`DataFrame.shift` with ``axis=1`` and heterogeneous dtypes (:issue:`35488`)
2323
- Fixed regression in ``.groupby(..).rolling(..)`` where a segfault would occur with ``center=True`` and an odd number of values (:issue:`35552`)
2424
- Fixed regression in :meth:`DataFrame.apply` where functions that altered the input in-place only operated on a single row (:issue:`35462`)
25+
- Fixed regression in ``.groupby(..).rolling(..)`` where a custom ``BaseIndexer`` would be ignored (:issue:`35557`)
2526

2627
.. ---------------------------------------------------------------------------
2728

pandas/core/window/indexers.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Indexer objects for computing start/end window bounds for rolling operations"""
22
from datetime import timedelta
3-
from typing import Dict, Optional, Tuple, Type, Union
3+
from typing import Dict, Optional, Tuple, Type
44

55
import numpy as np
66

@@ -265,7 +265,8 @@ def __init__(
265265
index_array: Optional[np.ndarray],
266266
window_size: int,
267267
groupby_indicies: Dict,
268-
rolling_indexer: Union[Type[FixedWindowIndexer], Type[VariableWindowIndexer]],
268+
rolling_indexer: Type[BaseIndexer],
269+
indexer_kwargs: Optional[Dict],
269270
**kwargs,
270271
):
271272
"""
@@ -276,7 +277,10 @@ def __init__(
276277
"""
277278
self.groupby_indicies = groupby_indicies
278279
self.rolling_indexer = rolling_indexer
279-
super().__init__(index_array, window_size, **kwargs)
280+
self.indexer_kwargs = indexer_kwargs or {}
281+
super().__init__(
282+
index_array, self.indexer_kwargs.pop("window_size", window_size), **kwargs
283+
)
280284

281285
@Appender(get_window_bounds_doc)
282286
def get_window_bounds(
@@ -298,7 +302,9 @@ def get_window_bounds(
298302
else:
299303
index_array = self.index_array
300304
indexer = self.rolling_indexer(
301-
index_array=index_array, window_size=self.window_size,
305+
index_array=index_array,
306+
window_size=self.window_size,
307+
**self.indexer_kwargs,
302308
)
303309
start, end = indexer.get_window_bounds(
304310
len(indicies), min_periods, center, closed

pandas/core/window/rolling.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class _Window(PandasObject, ShallowMixin, SelectionMixin):
145145

146146
def __init__(
147147
self,
148-
obj,
148+
obj: FrameOrSeries,
149149
window=None,
150150
min_periods: Optional[int] = None,
151151
center: bool = False,
@@ -2255,10 +2255,16 @@ def _get_window_indexer(self, window: int) -> GroupbyRollingIndexer:
22552255
-------
22562256
GroupbyRollingIndexer
22572257
"""
2258-
rolling_indexer: Union[Type[FixedWindowIndexer], Type[VariableWindowIndexer]]
2259-
if self.is_freq_type:
2258+
rolling_indexer: Type[BaseIndexer]
2259+
indexer_kwargs: Optional[Dict] = None
2260+
index_array = self.obj.index.asi8
2261+
if isinstance(self.window, BaseIndexer):
2262+
rolling_indexer = type(self.window)
2263+
indexer_kwargs = self.window.__dict__
2264+
# We'll be using the index of each group later
2265+
indexer_kwargs.pop("index_array", None)
2266+
elif self.is_freq_type:
22602267
rolling_indexer = VariableWindowIndexer
2261-
index_array = self.obj.index.asi8
22622268
else:
22632269
rolling_indexer = FixedWindowIndexer
22642270
index_array = None
@@ -2267,6 +2273,7 @@ def _get_window_indexer(self, window: int) -> GroupbyRollingIndexer:
22672273
window_size=window,
22682274
groupby_indicies=self._groupby.indices,
22692275
rolling_indexer=rolling_indexer,
2276+
indexer_kwargs=indexer_kwargs,
22702277
)
22712278
return window_indexer
22722279

pandas/tests/window/test_grouper.py

+23
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,29 @@ def test_groupby_subselect_rolling(self):
305305
)
306306
tm.assert_series_equal(result, expected)
307307

308+
def test_groupby_rolling_custom_indexer(self):
309+
# GH 35557
310+
class SimpleIndexer(pd.api.indexers.BaseIndexer):
311+
def get_window_bounds(
312+
self, num_values=0, min_periods=None, center=None, closed=None
313+
):
314+
min_periods = self.window_size if min_periods is None else 0
315+
end = np.arange(num_values, dtype=np.int64) + 1
316+
start = end.copy() - self.window_size
317+
start[start < 0] = min_periods
318+
return start, end
319+
320+
df = pd.DataFrame(
321+
{"a": [1.0, 2.0, 3.0, 4.0, 5.0] * 3}, index=[0] * 5 + [1] * 5 + [2] * 5
322+
)
323+
result = (
324+
df.groupby(df.index)
325+
.rolling(SimpleIndexer(window_size=3), min_periods=1)
326+
.sum()
327+
)
328+
expected = df.groupby(df.index).rolling(window=3, min_periods=1).sum()
329+
tm.assert_frame_equal(result, expected)
330+
308331
def test_groupby_rolling_subset_with_closed(self):
309332
# GH 35549
310333
df = pd.DataFrame(

0 commit comments

Comments
 (0)