Commit ed4b867

BUG: Check min_periods before applying the function (#58886)
* Check min_periods before calling the function
* Update whatsnew
1 parent f49b286 commit ed4b867

3 files changed, +18 -3 lines changed

doc/source/whatsnew/v3.0.0.rst

+1 -1
@@ -536,7 +536,7 @@ Groupby/resample/rolling
 - Bug in :meth:`DataFrameGroupBy.apply` that was returning a completely empty DataFrame when all return values of ``func`` were ``None`` instead of returning an empty DataFrame with the original columns and dtypes. (:issue:`57775`)
 - Bug in :meth:`DataFrameGroupBy.apply` with ``as_index=False`` that was returning :class:`MultiIndex` instead of returning :class:`Index`. (:issue:`58291`)
 - Bug in :meth:`DataFrameGroupby.transform` and :meth:`SeriesGroupby.transform` with a reducer and ``observed=False`` that coerces dtype to float when there are unobserved categories. (:issue:`55326`)
-
+- Bug in :meth:`Rolling.apply` where the applied function could be called on fewer than ``min_period`` periods if ``method="table"``. (:issue:`58868`)
 
 Reshaping
 ^^^^^^^^^

pandas/core/window/numba_.py

+2 -2
@@ -227,10 +227,10 @@ def roll_table(
             stop = end[i]
             window = values[start:stop]
             count_nan = np.sum(np.isnan(window), axis=0)
-            sub_result = numba_func(window, *args)
             nan_mask = len(window) - count_nan >= minimum_periods
+            if nan_mask.any():
+                result[i, :] = numba_func(window, *args)
             min_periods_mask[i, :] = nan_mask
-            result[i, :] = sub_result
         result = np.where(min_periods_mask, result, np.nan)
         return result

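For readers without numba installed, the patched loop can be approximated in plain NumPy. This is a minimal sketch, not the pandas implementation: `roll_table_sketch` is a hypothetical name, the user function is passed in as `func` rather than captured from a closure, and the `begin`/`end` arrays are hand-written bounds assumed to match a right-aligned window of 3 over four rows.

```python
import numpy as np


def roll_table_sketch(values, begin, end, minimum_periods, func, *args):
    """Plain-NumPy stand-in for the patched loop (hypothetical helper)."""
    result = np.empty((len(begin), values.shape[1]))
    min_periods_mask = np.empty(result.shape, dtype=bool)
    for i in range(len(result)):
        window = values[begin[i]:end[i]]
        count_nan = np.sum(np.isnan(window), axis=0)
        # Per-column check: enough non-NaN observations in this window?
        nan_mask = len(window) - count_nan >= minimum_periods
        # Only call the user function when at least one column qualifies.
        if nan_mask.any():
            result[i, :] = func(window, *args)
        min_periods_mask[i, :] = nan_mask
    # Rows that were skipped hold uninitialized values; mask them to NaN.
    return np.where(min_periods_mask, result, np.nan)


def last_row(x):
    # Mirrors the regression test: fails if called on a short window.
    assert len(x) == 3
    return x[-1]


values = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
begin = np.array([0, 0, 0, 1])  # assumed window starts
end = np.array([1, 2, 3, 4])    # assumed window stops (exclusive)

out = roll_table_sketch(values, begin, end, 3, last_row)
```

With the pre-patch ordering (calling `func` before computing `nan_mask`), `last_row` would hit its assertion on the first two windows, which contain only one and two rows.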

pandas/tests/window/test_numba.py

+15
@@ -67,6 +67,21 @@ def f(x, *args):
         )
         tm.assert_series_equal(result, expected)
 
+    def test_numba_min_periods(self):
+        # GH 58868
+        def last_row(x):
+            assert len(x) == 3
+            return x[-1]
+
+        df = DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]])
+
+        result = df.rolling(3, method="table", min_periods=3).apply(
+            last_row, raw=True, engine="numba"
+        )
+
+        expected = DataFrame([[np.nan, np.nan], [np.nan, np.nan], [5, 6], [7, 8]])
+        tm.assert_frame_equal(result, expected)
+
     @pytest.mark.parametrize(
         "data",
         [
