|
6 | 6 | from functools import partial
|
7 | 7 | import inspect
|
8 | 8 | from textwrap import dedent
|
9 |
| -from typing import Callable, Dict, List, Optional, Set, Tuple, Union |
| 9 | +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union |
10 | 10 |
|
11 | 11 | import numpy as np
|
12 | 12 |
|
|
37 | 37 | from pandas.core.base import DataError, PandasObject, SelectionMixin, ShallowMixin
|
38 | 38 | import pandas.core.common as com
|
39 | 39 | from pandas.core.construction import extract_array
|
40 |
| -from pandas.core.indexes.api import Index, ensure_index |
| 40 | +from pandas.core.indexes.api import Index, MultiIndex, ensure_index |
41 | 41 | from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
|
42 | 42 | from pandas.core.window.common import (
|
43 | 43 | WindowGroupByMixin,
|
|
49 | 49 | from pandas.core.window.indexers import (
|
50 | 50 | BaseIndexer,
|
51 | 51 | FixedWindowIndexer,
|
| 52 | + GroupbyRollingIndexer, |
52 | 53 | VariableWindowIndexer,
|
53 | 54 | )
|
54 | 55 | from pandas.core.window.numba_ import generate_numba_apply_func
|
@@ -219,12 +220,10 @@ def _validate_get_window_bounds_signature(window: BaseIndexer) -> None:
|
219 | 220 | f"get_window_bounds"
|
220 | 221 | )
|
221 | 222 |
|
222 |
| - def _create_blocks(self): |
| 223 | + def _create_blocks(self, obj: FrameOrSeries): |
223 | 224 | """
|
224 | 225 | Split data into blocks & return conformed data.
|
225 | 226 | """
|
226 |
| - obj = self._selected_obj |
227 |
| - |
228 | 227 | # filter out the on from the object
|
229 | 228 | if self.on is not None and not isinstance(self.on, Index):
|
230 | 229 | if obj.ndim == 2:
|
@@ -320,7 +319,7 @@ def __repr__(self) -> str:
|
320 | 319 |
|
321 | 320 | def __iter__(self):
|
322 | 321 | window = self._get_window(win_type=None)
|
323 |
| - blocks, obj = self._create_blocks() |
| 322 | + blocks, obj = self._create_blocks(self._selected_obj) |
324 | 323 | index = self._get_window_indexer(window=window)
|
325 | 324 |
|
326 | 325 | start, end = index.get_window_bounds(
|
@@ -527,7 +526,7 @@ def _apply(
|
527 | 526 | win_type = self._get_win_type(kwargs)
|
528 | 527 | window = self._get_window(win_type=win_type)
|
529 | 528 |
|
530 |
| - blocks, obj = self._create_blocks() |
| 529 | + blocks, obj = self._create_blocks(self._selected_obj) |
531 | 530 | block_list = list(blocks)
|
532 | 531 | window_indexer = self._get_window_indexer(window)
|
533 | 532 |
|
@@ -1261,7 +1260,7 @@ def count(self):
|
1261 | 1260 | # implementations shouldn't end up here
|
1262 | 1261 | assert not isinstance(self.window, BaseIndexer)
|
1263 | 1262 |
|
1264 |
| - blocks, obj = self._create_blocks() |
| 1263 | + blocks, obj = self._create_blocks(self._selected_obj) |
1265 | 1264 | results = []
|
1266 | 1265 | for b in blocks:
|
1267 | 1266 | result = b.notna().astype(int)
|
@@ -2174,12 +2173,103 @@ class RollingGroupby(WindowGroupByMixin, Rolling):
|
2174 | 2173 | Provide a rolling groupby implementation.
|
2175 | 2174 | """
|
2176 | 2175 |
|
def _apply(
    self,
    func: Callable,
    center: bool,
    require_min_periods: int = 0,
    floor: int = 1,
    is_weighted: bool = False,
    name: Optional[str] = None,
    use_numba_cache: bool = False,
    **kwargs,
):
    """
    Run ``Rolling._apply`` and relabel the result with a group-aware index.

    All arguments are forwarded unchanged to ``Rolling._apply``. The index
    of the returned object is then replaced by a ``MultiIndex`` whose
    leading levels are the groupby keys and whose trailing level(s) are the
    index of the grouped object.

    Parameters
    ----------
    func : callable
    center : bool
    require_min_periods : int, default 0
    floor : int, default 1
    is_weighted : bool, default False
    name : str, optional
    use_numba_cache : bool, default False
    **kwargs
        Passed through to ``Rolling._apply``.

    Returns
    -------
    The object returned by ``Rolling._apply`` with its ``index`` replaced.
    """
    result = Rolling._apply(
        self,
        func,
        center,
        require_min_periods,
        floor,
        is_weighted,
        name,
        use_numba_cache,
        **kwargs,
    )
    # Cannot use _wrap_outputs because we calculate the result all at once
    # Compose MultiIndex result from grouping levels then rolling level(s)
    grouped_object_index = self._groupby._selected_obj.index

    # A MultiIndex contributes one result level per level of its own; using
    # ``.name`` / a single label here would collapse those levels into one
    # level of tuples and lose the level names.
    index_is_multi = isinstance(grouped_object_index, MultiIndex)
    if index_is_multi:
        grouped_index_names = list(grouped_object_index.names)
    else:
        grouped_index_names = [grouped_object_index.name]

    groupby_keys = [grouping.name for grouping in self._groupby.grouper._groupings]
    result_index_names = groupby_keys + grouped_index_names

    # Walk the groups in the order of ``grouper.indices``; within a group,
    # each integer position is looked up in the grouped object's index.
    result_index_data = []
    for key, positions in self._groupby.grouper.indices.items():
        # Normalise the group key once per group rather than once per row.
        key_part = tuple(key) if is_list_like(key) else (key,)
        for position in positions:
            label = grouped_object_index[position]
            label_part = tuple(label) if index_is_multi else (label,)
            result_index_data.append(key_part + label_part)

    result_index = MultiIndex.from_tuples(
        result_index_data, names=result_index_names
    )
    result.index = result_index
    return result
| 2220 | + |
@property
def _constructor(self):
    """
    Return the ``Rolling`` class as the constructor for this object.
    """
    return Rolling
|
2180 | 2224 |
|
2181 |
| - def _gotitem(self, key, ndim, subset=None): |
def _create_blocks(self, obj: FrameOrSeries):
    """
    Reorder ``obj`` group by group, then delegate block splitting.

    The integer positions stored for every group in
    ``self._groupby.grouper.indices`` are concatenated in dict order and
    used to ``take`` from ``obj``, so the rows of each group are laid out
    contiguously before the parent ``_create_blocks`` runs.
    """
    per_group_positions = list(self._groupby.grouper.indices.values())
    grouped_order = np.concatenate(per_group_positions)
    reordered = obj.take(grouped_order)
    return super()._create_blocks(reordered)
| 2233 | + |
def _get_cython_func_type(self, func: str) -> Callable:
    """
    Look up the "variable" flavour of a rolling function.

    RollingGroupby processes the data in group order, which may not be
    monotonic with the data as the "fixed" algorithms assume, so the
    ``"<func>_variable"`` implementation is always requested.
    """
    variable_name = f"{func}_variable"
    return self._get_roll_func(variable_name)
| 2243 | + |
def _get_window_indexer(self, window: int) -> GroupbyRollingIndexer:
    """
    Build the indexer that computes window start and end bounds per group.

    Parameters
    ----------
    window : int
        Passed to ``GroupbyRollingIndexer`` as ``window_size``.

    Returns
    -------
    GroupbyRollingIndexer
    """
    rolling_indexer: Union[Type[FixedWindowIndexer], Type[VariableWindowIndexer]]
    if not self.is_freq_type:
        # No index values are supplied alongside the fixed indexer.
        index_array = None
        rolling_indexer = FixedWindowIndexer
    else:
        # Frequency-based windows use the integer view of the grouped
        # object's index together with the variable indexer.
        index_array = self._groupby._selected_obj.index.asi8
        rolling_indexer = VariableWindowIndexer
    return GroupbyRollingIndexer(
        index_array=index_array,
        window_size=window,
        groupby_indicies=self._groupby.indices,
        rolling_indexer=rolling_indexer,
    )
| 2271 | + |
| 2272 | + def _gotitem(self, key, ndim, subset=None): |
2183 | 2273 | # we are setting the index on the actual object
|
2184 | 2274 | # here so our index is carried thru to the selected obj
|
2185 | 2275 | # when we do the splitting for the groupby
|
|
0 commit comments