Skip to content

CLN/REF: Refactor min_periods calculation in rolling #37156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 16, 2020
17 changes: 10 additions & 7 deletions pandas/_libs/window/aggregations.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def roll_var(ndarray[float64_t] values, ndarray[int64_t] start,
ndarray[float64_t] output
bint is_monotonic_bounds

minp = max(minp, 1)
is_monotonic_bounds = is_monotonic_start_end_bounds(start, end)
output = np.empty(N, dtype=float)

Expand Down Expand Up @@ -487,6 +488,7 @@ def roll_skew(ndarray[float64_t] values, ndarray[int64_t] start,
ndarray[float64_t] output
bint is_monotonic_bounds

minp = max(minp, 3)
is_monotonic_bounds = is_monotonic_start_end_bounds(start, end)
output = np.empty(N, dtype=float)

Expand Down Expand Up @@ -611,6 +613,7 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
ndarray[float64_t] output
bint is_monotonic_bounds

minp = max(minp, 4)
is_monotonic_bounds = is_monotonic_start_end_bounds(start, end)
output = np.empty(N, dtype=float)

Expand Down Expand Up @@ -655,15 +658,15 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,


def roll_median_c(ndarray[float64_t] values, ndarray[int64_t] start,
ndarray[int64_t] end, int64_t minp, int64_t win=0):
ndarray[int64_t] end, int64_t minp):
# GH 32865. win argument kept for compatibility
cdef:
float64_t val, res, prev
bint err = False
int ret = 0
skiplist_t *sl
Py_ssize_t i, j
int64_t nobs = 0, N = len(values), s, e
int64_t nobs = 0, N = len(values), s, e, win
int midpoint
ndarray[float64_t] output

Expand Down Expand Up @@ -721,6 +724,8 @@ def roll_median_c(ndarray[float64_t] values, ndarray[int64_t] start,
else:
res = (skiplist_get(sl, midpoint, &ret) +
skiplist_get(sl, (midpoint - 1), &ret)) / 2
if ret == 0:
res = NaN
else:
res = NaN

Expand Down Expand Up @@ -1008,6 +1013,9 @@ def roll_quantile(ndarray[float64_t, cast=True] values, ndarray[int64_t] start,
vlow = skiplist_get(skiplist, idx, &ret)
vhigh = skiplist_get(skiplist, idx + 1, &ret)
output[i] = <float64_t>(vlow + vhigh) / 2

if ret == 0:
output[i] = NaN
else:
output[i] = NaN

Expand Down Expand Up @@ -1087,13 +1095,8 @@ cdef ndarray[float64_t] _roll_weighted_sum_mean(float64_t[:] values,
if avg:
tot_wgt = np.zeros(in_n, dtype=np.float64)

if minp > win_n:
raise ValueError(f"min_periods (minp) must be <= "
f"window (win)")
elif minp > in_n:
minp = in_n + 1
elif minp < 0:
raise ValueError('min_periods must be >= 0')

minp = max(minp, 1)

Expand Down
84 changes: 16 additions & 68 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,41 +74,6 @@
from pandas.core.internals import Block # noqa:F401


def calculate_min_periods(
    window: int,
    min_periods: Optional[int],
    num_values: int,
    required_min_periods: int,
    floor: int,
) -> int:
    """
    Calculate final minimum periods value for rolling aggregations.

    Parameters
    ----------
    window : passed window value
    min_periods : passed min periods value; ``None`` means "use ``window``"
    num_values : total number of values
    required_min_periods : required min periods per aggregation function;
        an explicitly passed ``min_periods`` is raised to at least this value
    floor : lower bound applied to the returned value

    Returns
    -------
    min_periods : int

    Raises
    ------
    ValueError
        If ``min_periods`` is negative, or if the effective ``min_periods``
        (after applying ``required_min_periods``) exceeds ``window``.
    """
    if min_periods is None:
        min_periods = window
    else:
        # Validate the caller's value *before* clamping it; otherwise
        # max(required_min_periods, min_periods) silently turns a negative
        # min_periods into a valid one and the error below never fires.
        if min_periods < 0:
            raise ValueError("min_periods must be >= 0")
        min_periods = max(required_min_periods, min_periods)

    if min_periods > window:
        raise ValueError(f"min_periods {min_periods} must be <= window {window}")

    if min_periods > num_values:
        # More observations required than exist: use num_values + 1 so that
        # no window can ever satisfy the requirement.
        min_periods = num_values + 1
    elif min_periods < 0:
        # Still reachable when min_periods was None and window is negative.
        raise ValueError("min_periods must be >= 0")

    return max(min_periods, floor)


class BaseWindow(ShallowMixin, SelectionMixin):
"""Provides utilities for performing windowing operations."""

Expand Down Expand Up @@ -163,8 +128,15 @@ def is_freq_type(self) -> bool:
def validate(self) -> None:
if self.center is not None and not is_bool(self.center):
raise ValueError("center must be a boolean")
if self.min_periods is not None and not is_integer(self.min_periods):
raise ValueError("min_periods must be an integer")
if self.min_periods is not None:
if not is_integer(self.min_periods):
raise ValueError("min_periods must be an integer")
elif self.min_periods < 0:
raise ValueError("min_periods must be >= 0")
elif is_integer(self.window) and self.min_periods > self.window:
raise ValueError(
f"min_periods {self.min_periods} must be <= window {self.window}"
)
if self.closed is not None and self.closed not in [
"right",
"both",
Expand Down Expand Up @@ -433,8 +405,6 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
def _apply(
self,
func: Callable[..., Any],
require_min_periods: int = 0,
floor: int = 1,
name: Optional[str] = None,
use_numba_cache: bool = False,
**kwargs,
Expand All @@ -447,8 +417,6 @@ def _apply(
Parameters
----------
func : callable function to apply
require_min_periods : int
floor : int
name : str,
use_numba_cache : bool
whether to cache a numba compiled function. Only available for numba
Expand All @@ -462,6 +430,11 @@ def _apply(
"""
window = self._get_window()
window_indexer = self._get_window_indexer(window)
min_periods = (
self.min_periods
if self.min_periods is not None
else window_indexer.window_size
)

def homogeneous_func(values: np.ndarray):
# calculation function
Expand All @@ -470,21 +443,9 @@ def homogeneous_func(values: np.ndarray):
return values.copy()

def calc(x):
if not isinstance(self.window, BaseIndexer):
min_periods = calculate_min_periods(
window, self.min_periods, len(x), require_min_periods, floor
)
else:
min_periods = calculate_min_periods(
window_indexer.window_size,
self.min_periods,
len(x),
require_min_periods,
floor,
)
start, end = window_indexer.get_window_bounds(
num_values=len(x),
min_periods=self.min_periods,
min_periods=min_periods,
center=self.center,
closed=self.closed,
)
Expand Down Expand Up @@ -793,16 +754,12 @@ def __init__(self, obj, *args, **kwargs):
def _apply(
self,
func: Callable[..., Any],
require_min_periods: int = 0,
floor: int = 1,
name: Optional[str] = None,
use_numba_cache: bool = False,
**kwargs,
) -> FrameOrSeries:
result = super()._apply(
func,
require_min_periods,
floor,
name,
use_numba_cache,
**kwargs,
Expand Down Expand Up @@ -1151,8 +1108,6 @@ def _get_window_weights(
def _apply(
self,
func: Callable[[np.ndarray, int, int], np.ndarray],
require_min_periods: int = 0,
floor: int = 1,
name: Optional[str] = None,
use_numba_cache: bool = False,
**kwargs,
Expand All @@ -1165,8 +1120,6 @@ def _apply(
Parameters
----------
func : callable function to apply
require_min_periods : int
floor : int
name : str,
use_numba_cache : bool
whether to cache a numba compiled function. Only available for numba
Expand Down Expand Up @@ -1420,7 +1373,6 @@ def apply(

return self._apply(
apply_func,
floor=0,
use_numba_cache=maybe_use_numba(engine),
original_func=func,
args=args,
Expand Down Expand Up @@ -1454,7 +1406,7 @@ def apply_func(values, begin, end, min_periods, raw=raw):
def sum(self, *args, **kwargs):
nv.validate_window_func("sum", args, kwargs)
window_func = self._get_roll_func("roll_sum")
return self._apply(window_func, floor=0, name="sum", **kwargs)
return self._apply(window_func, name="sum", **kwargs)

_shared_docs["max"] = dedent(
"""
Expand Down Expand Up @@ -1571,7 +1523,6 @@ def zsqrt_func(values, begin, end, min_periods):

return self._apply(
zsqrt_func,
require_min_periods=1,
name="std",
**kwargs,
)
Expand All @@ -1581,7 +1532,6 @@ def var(self, ddof: int = 1, *args, **kwargs):
window_func = partial(self._get_roll_func("roll_var"), ddof=ddof)
return self._apply(
window_func,
require_min_periods=1,
name="var",
**kwargs,
)
Expand All @@ -1601,7 +1551,6 @@ def skew(self, **kwargs):
window_func = self._get_roll_func("roll_skew")
return self._apply(
window_func,
require_min_periods=3,
name="skew",
**kwargs,
)
Expand Down Expand Up @@ -1695,7 +1644,6 @@ def kurt(self, **kwargs):
window_func = self._get_roll_func("roll_kurt")
return self._apply(
window_func,
require_min_periods=4,
name="kurt",
**kwargs,
)
Expand Down
8 changes: 4 additions & 4 deletions pandas/tests/window/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_rolling_count_default_min_periods_with_null_values(constructor):
({"A": [2, 3], "B": [5, 6]}, [1, 2]),
],
2,
3,
2,
Copy link
Member Author

@mroeschke mroeschke Oct 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test changes are related to the whatsnew entry: This should not have passed previously.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what whatsnew entry?

Copy link
Member Author

@mroeschke mroeschke Oct 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I thought I had added it previously. It should be there now.

),
(
DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}),
Expand All @@ -518,7 +518,7 @@ def test_rolling_count_default_min_periods_with_null_values(constructor):
({"A": [3], "B": [6]}, [2]),
],
1,
2,
0,
),
(DataFrame({"A": [1], "B": [4]}), [], 2, None),
(DataFrame({"A": [1], "B": [4]}), [], 2, 1),
Expand Down Expand Up @@ -605,9 +605,9 @@ def test_iter_rolling_on_dataframe(expected, window):
1,
),
(Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], 2, 1),
(Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], 2, 3),
(Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], 2, 2),
(Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 0),
(Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 2),
(Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 1),
(Series([1, 2]), [([1], [0]), ([1, 2], [0, 1])], 2, 0),
(Series([], dtype="int64"), [], 2, 1),
],
Expand Down