Skip to content

Commit 710148e

Browse files
authored
CLN/REF: Refactor min_periods calculation in rolling (#37156)
1 parent 8cd42e1 commit 710148e

File tree

4 files changed

+31
-79
lines changed

4 files changed

+31
-79
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ Groupby/resample/rolling
462462
- Bug in :meth:`RollingGroupby.count` where a ``ValueError`` was raised when specifying the ``closed`` parameter (:issue:`35869`)
463463
- Bug in :meth:`DataFrame.groupby.rolling` returning wrong values with partial centered window (:issue:`36040`).
464464
- Bug in :meth:`DataFrameGroupBy.rolling` returned wrong values with timeaware window containing ``NaN``. Raises ``ValueError`` because windows are not monotonic now (:issue:`34617`)
465+
- Bug in :meth:`Rolling.__iter__` where a ``ValueError`` was not raised when ``min_periods`` was larger than ``window`` (:issue:`37156`)
465466

466467
Reshaping
467468
^^^^^^^^^

pandas/_libs/window/aggregations.pyx

+10-7
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ def roll_var(ndarray[float64_t] values, ndarray[int64_t] start,
369369
ndarray[float64_t] output
370370
bint is_monotonic_bounds
371371

372+
minp = max(minp, 1)
372373
is_monotonic_bounds = is_monotonic_start_end_bounds(start, end)
373374
output = np.empty(N, dtype=float)
374375

@@ -487,6 +488,7 @@ def roll_skew(ndarray[float64_t] values, ndarray[int64_t] start,
487488
ndarray[float64_t] output
488489
bint is_monotonic_bounds
489490

491+
minp = max(minp, 3)
490492
is_monotonic_bounds = is_monotonic_start_end_bounds(start, end)
491493
output = np.empty(N, dtype=float)
492494

@@ -611,6 +613,7 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
611613
ndarray[float64_t] output
612614
bint is_monotonic_bounds
613615

616+
minp = max(minp, 4)
614617
is_monotonic_bounds = is_monotonic_start_end_bounds(start, end)
615618
output = np.empty(N, dtype=float)
616619

@@ -655,15 +658,15 @@ def roll_kurt(ndarray[float64_t] values, ndarray[int64_t] start,
655658

656659

657660
def roll_median_c(ndarray[float64_t] values, ndarray[int64_t] start,
658-
ndarray[int64_t] end, int64_t minp, int64_t win=0):
661+
ndarray[int64_t] end, int64_t minp):
659662
# GH 32865. win argument kept for compatibility
660663
cdef:
661664
float64_t val, res, prev
662665
bint err = False
663666
int ret = 0
664667
skiplist_t *sl
665668
Py_ssize_t i, j
666-
int64_t nobs = 0, N = len(values), s, e
669+
int64_t nobs = 0, N = len(values), s, e, win
667670
int midpoint
668671
ndarray[float64_t] output
669672

@@ -721,6 +724,8 @@ def roll_median_c(ndarray[float64_t] values, ndarray[int64_t] start,
721724
else:
722725
res = (skiplist_get(sl, midpoint, &ret) +
723726
skiplist_get(sl, (midpoint - 1), &ret)) / 2
727+
if ret == 0:
728+
res = NaN
724729
else:
725730
res = NaN
726731

@@ -1008,6 +1013,9 @@ def roll_quantile(ndarray[float64_t, cast=True] values, ndarray[int64_t] start,
10081013
vlow = skiplist_get(skiplist, idx, &ret)
10091014
vhigh = skiplist_get(skiplist, idx + 1, &ret)
10101015
output[i] = <float64_t>(vlow + vhigh) / 2
1016+
1017+
if ret == 0:
1018+
output[i] = NaN
10111019
else:
10121020
output[i] = NaN
10131021

@@ -1087,13 +1095,8 @@ cdef ndarray[float64_t] _roll_weighted_sum_mean(float64_t[:] values,
10871095
if avg:
10881096
tot_wgt = np.zeros(in_n, dtype=np.float64)
10891097

1090-
if minp > win_n:
1091-
raise ValueError(f"min_periods (minp) must be <= "
1092-
f"window (win)")
10931098
elif minp > in_n:
10941099
minp = in_n + 1
1095-
elif minp < 0:
1096-
raise ValueError('min_periods must be >= 0')
10971100

10981101
minp = max(minp, 1)
10991102

pandas/core/window/rolling.py

+16-68
Original file line numberDiff line numberDiff line change
@@ -74,41 +74,6 @@
7474
from pandas.core.internals import Block # noqa:F401
7575

7676

77-
def calculate_min_periods(
78-
window: int,
79-
min_periods: Optional[int],
80-
num_values: int,
81-
required_min_periods: int,
82-
floor: int,
83-
) -> int:
84-
"""
85-
Calculate final minimum periods value for rolling aggregations.
86-
87-
Parameters
88-
----------
89-
window : passed window value
90-
min_periods : passed min periods value
91-
num_values : total number of values
92-
required_min_periods : required min periods per aggregation function
93-
floor : required min periods per aggregation function
94-
95-
Returns
96-
-------
97-
min_periods : int
98-
"""
99-
if min_periods is None:
100-
min_periods = window
101-
else:
102-
min_periods = max(required_min_periods, min_periods)
103-
if min_periods > window:
104-
raise ValueError(f"min_periods {min_periods} must be <= window {window}")
105-
elif min_periods > num_values:
106-
min_periods = num_values + 1
107-
elif min_periods < 0:
108-
raise ValueError("min_periods must be >= 0")
109-
return max(min_periods, floor)
110-
111-
11277
class BaseWindow(ShallowMixin, SelectionMixin):
11378
"""Provides utilities for performing windowing operations."""
11479

@@ -163,8 +128,15 @@ def is_freq_type(self) -> bool:
163128
def validate(self) -> None:
164129
if self.center is not None and not is_bool(self.center):
165130
raise ValueError("center must be a boolean")
166-
if self.min_periods is not None and not is_integer(self.min_periods):
167-
raise ValueError("min_periods must be an integer")
131+
if self.min_periods is not None:
132+
if not is_integer(self.min_periods):
133+
raise ValueError("min_periods must be an integer")
134+
elif self.min_periods < 0:
135+
raise ValueError("min_periods must be >= 0")
136+
elif is_integer(self.window) and self.min_periods > self.window:
137+
raise ValueError(
138+
f"min_periods {self.min_periods} must be <= window {self.window}"
139+
)
168140
if self.closed is not None and self.closed not in [
169141
"right",
170142
"both",
@@ -433,8 +405,6 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
433405
def _apply(
434406
self,
435407
func: Callable[..., Any],
436-
require_min_periods: int = 0,
437-
floor: int = 1,
438408
name: Optional[str] = None,
439409
use_numba_cache: bool = False,
440410
**kwargs,
@@ -447,8 +417,6 @@ def _apply(
447417
Parameters
448418
----------
449419
func : callable function to apply
450-
require_min_periods : int
451-
floor : int
452420
name : str,
453421
use_numba_cache : bool
454422
whether to cache a numba compiled function. Only available for numba
@@ -462,6 +430,11 @@ def _apply(
462430
"""
463431
window = self._get_window()
464432
window_indexer = self._get_window_indexer(window)
433+
min_periods = (
434+
self.min_periods
435+
if self.min_periods is not None
436+
else window_indexer.window_size
437+
)
465438

466439
def homogeneous_func(values: np.ndarray):
467440
# calculation function
@@ -470,21 +443,9 @@ def homogeneous_func(values: np.ndarray):
470443
return values.copy()
471444

472445
def calc(x):
473-
if not isinstance(self.window, BaseIndexer):
474-
min_periods = calculate_min_periods(
475-
window, self.min_periods, len(x), require_min_periods, floor
476-
)
477-
else:
478-
min_periods = calculate_min_periods(
479-
window_indexer.window_size,
480-
self.min_periods,
481-
len(x),
482-
require_min_periods,
483-
floor,
484-
)
485446
start, end = window_indexer.get_window_bounds(
486447
num_values=len(x),
487-
min_periods=self.min_periods,
448+
min_periods=min_periods,
488449
center=self.center,
489450
closed=self.closed,
490451
)
@@ -793,16 +754,12 @@ def __init__(self, obj, *args, **kwargs):
793754
def _apply(
794755
self,
795756
func: Callable[..., Any],
796-
require_min_periods: int = 0,
797-
floor: int = 1,
798757
name: Optional[str] = None,
799758
use_numba_cache: bool = False,
800759
**kwargs,
801760
) -> FrameOrSeries:
802761
result = super()._apply(
803762
func,
804-
require_min_periods,
805-
floor,
806763
name,
807764
use_numba_cache,
808765
**kwargs,
@@ -1151,8 +1108,6 @@ def _get_window_weights(
11511108
def _apply(
11521109
self,
11531110
func: Callable[[np.ndarray, int, int], np.ndarray],
1154-
require_min_periods: int = 0,
1155-
floor: int = 1,
11561111
name: Optional[str] = None,
11571112
use_numba_cache: bool = False,
11581113
**kwargs,
@@ -1165,8 +1120,6 @@ def _apply(
11651120
Parameters
11661121
----------
11671122
func : callable function to apply
1168-
require_min_periods : int
1169-
floor : int
11701123
name : str,
11711124
use_numba_cache : bool
11721125
whether to cache a numba compiled function. Only available for numba
@@ -1420,7 +1373,6 @@ def apply(
14201373

14211374
return self._apply(
14221375
apply_func,
1423-
floor=0,
14241376
use_numba_cache=maybe_use_numba(engine),
14251377
original_func=func,
14261378
args=args,
@@ -1454,7 +1406,7 @@ def apply_func(values, begin, end, min_periods, raw=raw):
14541406
def sum(self, *args, **kwargs):
14551407
nv.validate_window_func("sum", args, kwargs)
14561408
window_func = self._get_roll_func("roll_sum")
1457-
return self._apply(window_func, floor=0, name="sum", **kwargs)
1409+
return self._apply(window_func, name="sum", **kwargs)
14581410

14591411
_shared_docs["max"] = dedent(
14601412
"""
@@ -1571,7 +1523,6 @@ def zsqrt_func(values, begin, end, min_periods):
15711523

15721524
return self._apply(
15731525
zsqrt_func,
1574-
require_min_periods=1,
15751526
name="std",
15761527
**kwargs,
15771528
)
@@ -1581,7 +1532,6 @@ def var(self, ddof: int = 1, *args, **kwargs):
15811532
window_func = partial(self._get_roll_func("roll_var"), ddof=ddof)
15821533
return self._apply(
15831534
window_func,
1584-
require_min_periods=1,
15851535
name="var",
15861536
**kwargs,
15871537
)
@@ -1601,7 +1551,6 @@ def skew(self, **kwargs):
16011551
window_func = self._get_roll_func("roll_skew")
16021552
return self._apply(
16031553
window_func,
1604-
require_min_periods=3,
16051554
name="skew",
16061555
**kwargs,
16071556
)
@@ -1695,7 +1644,6 @@ def kurt(self, **kwargs):
16951644
window_func = self._get_roll_func("roll_kurt")
16961645
return self._apply(
16971646
window_func,
1698-
require_min_periods=4,
16991647
name="kurt",
17001648
**kwargs,
17011649
)

pandas/tests/window/test_rolling.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def test_rolling_count_default_min_periods_with_null_values(constructor):
498498
({"A": [2, 3], "B": [5, 6]}, [1, 2]),
499499
],
500500
2,
501-
3,
501+
2,
502502
),
503503
(
504504
DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}),
@@ -518,7 +518,7 @@ def test_rolling_count_default_min_periods_with_null_values(constructor):
518518
({"A": [3], "B": [6]}, [2]),
519519
],
520520
1,
521-
2,
521+
0,
522522
),
523523
(DataFrame({"A": [1], "B": [4]}), [], 2, None),
524524
(DataFrame({"A": [1], "B": [4]}), [], 2, 1),
@@ -605,9 +605,9 @@ def test_iter_rolling_on_dataframe(expected, window):
605605
1,
606606
),
607607
(Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], 2, 1),
608-
(Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], 2, 3),
608+
(Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], 2, 2),
609609
(Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 0),
610-
(Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 2),
610+
(Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 1),
611611
(Series([1, 2]), [([1], [0]), ([1, 2], [0, 1])], 2, 0),
612612
(Series([], dtype="int64"), [], 2, 1),
613613
],

0 commit comments

Comments
 (0)