-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
API: reimplement FixedWindowIndexer.get_window_bounds to fix groupby bug #36132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
69f084f
71830c8
a449d9b
476fe83
9dfd9f3
f025600
5d902fd
59fcd3e
cdecf34
4e8f844
6e66a49
e7fb384
3649ca2
00cc1dc
3de7fcc
f779321
a817f87
d72812d
96c6959
daacae7
52a8a6b
950018c
0798c70
f413ec8
70679be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,24 +72,6 @@ | |
from pandas.core.internals import Block # noqa:F401 | ||
|
||
|
||
def calculate_center_offset(window) -> int: | ||
""" | ||
Calculate an offset necessary to have the window label to be centered. | ||
|
||
Parameters | ||
---------- | ||
window: ndarray or int | ||
window weights or window | ||
|
||
Returns | ||
------- | ||
int | ||
""" | ||
if not is_integer(window): | ||
window = len(window) | ||
return int((window - 1) / 2.0) | ||
|
||
|
||
def calculate_min_periods( | ||
window: int, | ||
min_periods: Optional[int], | ||
|
@@ -417,18 +399,44 @@ def _insert_on_column(self, result: "DataFrame", obj: "DataFrame"): | |
# insert at the end | ||
result[name] = extra_col | ||
|
||
def _center_window(self, result: np.ndarray, window) -> np.ndarray: | ||
def calculate_center_offset(self, window, center: bool) -> int: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pls type window: Union[np.ndarray, int] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doing that causes an error in typing validation:
|
||
""" | ||
Calculate an offset necessary to have the window label to be centered. | ||
|
||
Parameters | ||
---------- | ||
window : ndarray or int | ||
window weights or window | ||
center : bool | ||
Set the labels at the center of the window. | ||
|
||
justinessert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Returns | ||
------- | ||
int | ||
""" | ||
if not center: | ||
return 0 | ||
|
||
if self.is_freq_type or isinstance(self.window, BaseIndexer): | ||
return 0 | ||
|
||
if not is_integer(window): | ||
window = len(window) | ||
return int((window - 1) / 2.0) | ||
|
||
def _center_window(self, result: np.ndarray, window, center) -> np.ndarray: | ||
""" | ||
Center the result in the window. | ||
""" | ||
if self.axis > result.ndim - 1: | ||
raise ValueError("Requested axis is larger then no. of argument dimensions") | ||
|
||
offset = calculate_center_offset(window) | ||
offset = self.calculate_center_offset(window, center) | ||
if offset > 0: | ||
lead_indexer = [slice(None)] * result.ndim | ||
lead_indexer[self.axis] = slice(offset, None) | ||
result = np.copy(result[tuple(lead_indexer)]) | ||
|
||
return result | ||
|
||
def _get_roll_func(self, func_name: str) -> Callable: | ||
|
@@ -524,6 +532,7 @@ def _apply( | |
is_weighted: bool = False, | ||
name: Optional[str] = None, | ||
use_numba_cache: bool = False, | ||
skip_offset: bool = False, | ||
**kwargs, | ||
): | ||
""" | ||
|
@@ -543,6 +552,8 @@ def _apply( | |
use_numba_cache : bool | ||
whether to cache a numba compiled function. Only available for numba | ||
enabled methods (so far only apply) | ||
skip_offset : bool | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What's the point of an additional parameter here? This makes it really hard to understand. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @justinessert can you address |
||
whether to skip offsetting x | ||
**kwargs | ||
additional arguments for rolling function and window function | ||
|
||
|
@@ -560,7 +571,11 @@ def homogeneous_func(values: np.ndarray): | |
if values.size == 0: | ||
return values.copy() | ||
|
||
offset = calculate_center_offset(window) if center else 0 | ||
if skip_offset: | ||
offset = 0 | ||
else: | ||
offset = self.calculate_center_offset(window, center) | ||
|
||
additional_nans = np.array([np.nan] * offset) | ||
|
||
if not is_weighted: | ||
|
@@ -603,8 +618,8 @@ def calc(x): | |
if use_numba_cache: | ||
NUMBA_FUNC_CACHE[(kwargs["original_func"], "rolling_apply")] = func | ||
|
||
if center: | ||
result = self._center_window(result, window) | ||
if not skip_offset: | ||
result = self._center_window(result, window, center) | ||
|
||
return result | ||
|
||
|
@@ -1189,7 +1204,7 @@ def sum(self, *args, **kwargs): | |
window_func = self._get_roll_func("roll_weighted_sum") | ||
window_func = get_weighted_roll_func(window_func) | ||
return self._apply( | ||
window_func, center=self.center, is_weighted=True, name="sum", **kwargs | ||
window_func, center=self.center, is_weighted=True, name="sum", **kwargs, | ||
) | ||
|
||
@Substitution(name="window") | ||
|
@@ -1210,7 +1225,7 @@ def var(self, ddof=1, *args, **kwargs): | |
window_func = get_weighted_roll_func(window_func) | ||
kwargs.pop("name", None) | ||
return self._apply( | ||
window_func, center=self.center, is_weighted=True, name="var", **kwargs | ||
window_func, center=self.center, is_weighted=True, name="var", **kwargs, | ||
) | ||
|
||
@Substitution(name="window", versionadded="\n.. versionadded:: 1.0.0\n") | ||
|
@@ -1388,7 +1403,8 @@ def apply( | |
# Cython apply functions handle center, so don't need to use | ||
# _apply's center handling | ||
window = self._get_window() | ||
offset = calculate_center_offset(window) if self.center else 0 | ||
|
||
offset = self.calculate_center_offset(window, self.center) | ||
apply_func = self._generate_cython_apply_func( | ||
args, kwargs, raw, offset, func | ||
) | ||
|
@@ -1406,19 +1422,17 @@ def apply( | |
raw=raw, | ||
original_func=func, | ||
args=args, | ||
skip_offset=True, | ||
kwargs=kwargs, | ||
) | ||
|
||
def _generate_cython_apply_func(self, args, kwargs, raw, offset, func): | ||
from pandas import Series | ||
|
||
cython_func = self._get_cython_func_type("roll_generic") | ||
|
||
window_func = partial( | ||
self._get_cython_func_type("roll_generic"), | ||
args=args, | ||
kwargs=kwargs, | ||
raw=raw, | ||
offset=offset, | ||
func=func, | ||
cython_func, args=args, kwargs=kwargs, raw=raw, offset=offset, func=func, | ||
) | ||
|
||
def apply_func(values, begin, end, min_periods, raw=raw): | ||
|
@@ -1433,7 +1447,7 @@ def sum(self, *args, **kwargs): | |
window_func = self._get_cython_func_type("roll_sum") | ||
kwargs.pop("floor", None) | ||
return self._apply( | ||
window_func, center=self.center, floor=0, name="sum", **kwargs | ||
window_func, center=self.center, floor=0, name="sum", **kwargs, | ||
) | ||
|
||
_shared_docs["max"] = dedent( | ||
|
@@ -1540,7 +1554,9 @@ def median(self, **kwargs): | |
window_func = self._get_roll_func("roll_median_c") | ||
# GH 32865. Move max window size calculation to | ||
# the median function implementation | ||
return self._apply(window_func, center=self.center, name="median", **kwargs) | ||
return self._apply( | ||
window_func, center=self.center, name="median", skip_offset=True, **kwargs | ||
) | ||
|
||
def std(self, ddof=1, *args, **kwargs): | ||
nv.validate_window_func("std", args, kwargs) | ||
|
@@ -1563,7 +1579,8 @@ def zsqrt_func(values, begin, end, min_periods): | |
def var(self, ddof=1, *args, **kwargs): | ||
nv.validate_window_func("var", args, kwargs) | ||
kwargs.pop("require_min_periods", None) | ||
window_func = partial(self._get_cython_func_type("roll_var"), ddof=ddof) | ||
cython_func = self._get_cython_func_type("roll_var") | ||
window_func = partial(cython_func, ddof=ddof) | ||
# ddof passed again for compat with groupby.rolling | ||
return self._apply( | ||
window_func, | ||
|
@@ -1696,20 +1713,29 @@ def kurt(self, **kwargs): | |
def quantile(self, quantile, interpolation="linear", **kwargs): | ||
if quantile == 1.0: | ||
window_func = self._get_cython_func_type("roll_max") | ||
skip_offset = False | ||
elif quantile == 0.0: | ||
window_func = self._get_cython_func_type("roll_min") | ||
skip_offset = False | ||
else: | ||
window_func = partial( | ||
self._get_roll_func("roll_quantile"), | ||
win=self._get_window(), | ||
quantile=quantile, | ||
interpolation=interpolation, | ||
) | ||
skip_offset = True | ||
|
||
# Pass through for groupby.rolling | ||
kwargs["quantile"] = quantile | ||
kwargs["interpolation"] = interpolation | ||
return self._apply(window_func, center=self.center, name="quantile", **kwargs) | ||
return self._apply( | ||
window_func, | ||
center=self.center, | ||
name="quantile", | ||
skip_offset=skip_offset, | ||
**kwargs, | ||
) | ||
|
||
_shared_docs[ | ||
"cov" | ||
|
@@ -2189,6 +2215,7 @@ def _apply( | |
is_weighted: bool = False, | ||
name: Optional[str] = None, | ||
use_numba_cache: bool = False, | ||
skip_offset: bool = True, | ||
**kwargs, | ||
): | ||
result = Rolling._apply( | ||
|
@@ -2200,6 +2227,7 @@ def _apply( | |
is_weighted, | ||
name, | ||
use_numba_cache, | ||
skip_offset, | ||
**kwargs, | ||
) | ||
# Cannot use _wrap_outputs because we calculate the result all at once | ||
|
@@ -2243,6 +2271,31 @@ def _create_data(self, obj: FrameOrSeries) -> FrameOrSeries: | |
obj = obj.take(groupby_order) | ||
return super()._create_data(obj) | ||
|
||
def calculate_center_offset(self, window, center: bool) -> int: | ||
""" | ||
Calculate an offset necessary to have the window label to be centered. | ||
|
||
Parameters | ||
---------- | ||
window : ndarray or int | ||
window weights or window | ||
center : bool | ||
Set the labels at the center of the window. | ||
|
||
Returns | ||
------- | ||
int | ||
""" | ||
if not center or not self.win_type: | ||
return 0 | ||
|
||
if self.is_freq_type or isinstance(self.window, BaseIndexer): | ||
return 0 | ||
|
||
if not is_integer(window): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we not just make this a free function? i get that you are passing win_type here, but that could easily be passed in as an arg There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I again tried making this a free function and confirmed that the two classes require different functionality (whether or not to include There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can just add an optional 3rd kwarg There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did try that and it did not work. The issue is that in one class win_type should be completely ignored whereas the other class needs to use it. But both classes share the So I can't just do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok see my comment above; if this is now a method, then you can simply use is_freq_type and this PR is a lot simpler. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jreback thanks for the suggestion, I changed the code to use Nonetheless, I do think this is a simpler approach than I had previously, so thanks for the recommendation! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jreback Are you cool with this implementation? There was a problem hiding this comment. 
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand the skip_offset parameter. Why is it not sufficient to just make a property on the class itself (e.g. you can make another property / method if needed, similar to is_freq_type)? Passing parameters around like this is really hard to understand. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @justinessert I think same as above, can you explain why skip_offset arg is needed |
||
window = len(window) | ||
return int((window - 1) / 2.0) | ||
|
||
def _get_cython_func_type(self, func: str) -> Callable: | ||
""" | ||
Return the cython function type. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,6 +136,53 @@ def test_rolling_apply_consistency( | |
tm.assert_equal(rolling_f_result, rolling_apply_f_result) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"window,min_periods,center", list(_rolling_consistency_cases()) | ||
) | ||
def test_rolling_groupby(base_functions, window, min_periods, center): | ||
justinessert marked this conversation as resolved.
Show resolved
Hide resolved
|
||
base_df = DataFrame({"group": "A", "data": randn(20)}) | ||
|
||
b_df = base_df.copy() | ||
b_df["group"] = "B" | ||
|
||
grp_df = pd.concat([base_df, b_df]).groupby("group") | ||
|
||
for (f, require_min_periods, name) in base_functions: | ||
if ( | ||
require_min_periods | ||
and (min_periods is not None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are these skipped? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if min_periods is less than the required_min_periods then there will be an error thrown so we wouldn't be able to test the equivalency of the two dfs There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there's an error thrown test for that using |
||
and (min_periods < require_min_periods) | ||
): | ||
continue | ||
|
||
base_rolling_f = getattr( | ||
base_df[["data"]].rolling( | ||
window=window, center=center, min_periods=min_periods | ||
), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How is this actually testing center? Sure, you pass it, but unless we can see the expected results I don't have any idea whether this is correct. A small set of fixed cases that really show the input and output is much more useful here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fair, see response on your previous comment, but here I'm more testing the consistency of That being said, I'm open to creating another test to explicitly test the correctness of There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Haven't digested the details here but it sounds like you should hard code expected results and test against that |
||
name, | ||
) | ||
|
||
grp_rolling_f = getattr( | ||
grp_df[["data"]].rolling( | ||
window=window, center=center, min_periods=min_periods | ||
), | ||
name, | ||
) | ||
|
||
base_result = base_rolling_f().reset_index(drop=True) | ||
grp_result = grp_rolling_f().reset_index() | ||
|
||
a_result = grp_result[grp_result["group"] == "A"][["data"]].reset_index( | ||
drop=True | ||
) | ||
b_result = grp_result[grp_result["group"] == "B"][["data"]].reset_index( | ||
drop=True | ||
) | ||
|
||
tm.assert_frame_equal(base_result, a_result) | ||
tm.assert_frame_equal(base_result, b_result) | ||
|
||
|
||
@pytest.mark.parametrize("window", range(7)) | ||
def test_rolling_corr_with_zero_variance(window): | ||
# GH 18430 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
returning wrong values with partial window