@@ -848,7 +848,8 @@ def test_centered_axis_validation(self):
848
848
.rolling (window = 3 , center = True , axis = 2 ).mean ())
849
849
850
850
def test_rolling_sum (self ):
851
- self ._check_moment_func (np .sum , name = 'sum' )
851
+ self ._check_moment_func (np .nansum , name = 'sum' ,
852
+ zero_min_periods_equal = False )
852
853
853
854
def test_rolling_count (self ):
854
855
counter = lambda x : np .isfinite (x ).astype (float ).sum ()
@@ -1149,14 +1150,21 @@ def test_rolling_quantile_param(self):
1149
1150
ser .rolling (3 ).quantile ('foo' )
1150
1151
1151
1152
def test_rolling_apply (self ):
1152
- ser = Series ([])
1153
- tm .assert_series_equal (ser ,
1154
- ser .rolling (10 ).apply (lambda x : x .mean ()))
1153
+ # suppress warnings about empty slices, as we are deliberately testing
1154
+ # with a 0-length Series
1155
+ with warnings .catch_warnings ():
1156
+ warnings .filterwarnings ("ignore" ,
1157
+ message = ".*(empty slice|0 for slice).*" ,
1158
+ category = RuntimeWarning )
1159
+
1160
+ ser = Series ([])
1161
+ tm .assert_series_equal (ser ,
1162
+ ser .rolling (10 ).apply (lambda x : x .mean ()))
1155
1163
1156
- def f (x ):
1157
- return x [np .isfinite (x )].mean ()
1164
+ def f (x ):
1165
+ return x [np .isfinite (x )].mean ()
1158
1166
1159
- self ._check_moment_func (np .mean , name = 'apply' , func = f )
1167
+ self ._check_moment_func (np .mean , name = 'apply' , func = f )
1160
1168
1161
1169
# GH 8080
1162
1170
s = Series ([None , None , None ])
@@ -1232,18 +1240,25 @@ def test_rolling_kurt(self):
1232
1240
1233
1241
def _check_moment_func (self , static_comp , name , has_min_periods = True ,
1234
1242
has_center = True , has_time_rule = True ,
1235
- fill_value = None , ** kwargs ):
1243
+ fill_value = None , zero_min_periods_equal = True ,
1244
+ ** kwargs ):
1236
1245
1237
1246
def get_result (obj , window , min_periods = None , center = False ):
1238
1247
r = obj .rolling (window = window , min_periods = min_periods ,
1239
1248
center = center )
1240
1249
return getattr (r , name )(** kwargs )
1241
1250
1242
1251
series_result = get_result (self .series , window = 50 )
1243
- frame_result = get_result (self .frame , window = 50 )
1244
-
1245
1252
assert isinstance (series_result , Series )
1246
- assert type (frame_result ) == DataFrame
1253
+ tm .assert_almost_equal (series_result .iloc [- 1 ],
1254
+ static_comp (self .series [- 50 :]))
1255
+
1256
+ frame_result = get_result (self .frame , window = 50 )
1257
+ assert isinstance (frame_result , DataFrame )
1258
+ tm .assert_series_equal (frame_result .iloc [- 1 , :],
1259
+ self .frame .iloc [- 50 :, :].apply (static_comp ,
1260
+ axis = 0 ),
1261
+ check_names = False )
1247
1262
1248
1263
# check time_rule works
1249
1264
if has_time_rule :
@@ -1274,8 +1289,72 @@ def get_result(obj, window, min_periods=None, center=False):
1274
1289
trunc_frame .apply (static_comp ),
1275
1290
check_names = False )
1276
1291
1277
- # GH 7925
1292
+ # excluding NaNs correctly
1293
+ obj = Series (randn (50 ))
1294
+ obj [:10 ] = np .NaN
1295
+ obj [- 10 :] = np .NaN
1296
+ if has_min_periods :
1297
+ result = get_result (obj , 50 , min_periods = 30 )
1298
+ tm .assert_almost_equal (result .iloc [- 1 ], static_comp (obj [10 :- 10 ]))
1299
+
1300
+ # min_periods is working correctly
1301
+ result = get_result (obj , 20 , min_periods = 15 )
1302
+ assert isna (result .iloc [23 ])
1303
+ assert not isna (result .iloc [24 ])
1304
+
1305
+ assert not isna (result .iloc [- 6 ])
1306
+ assert isna (result .iloc [- 5 ])
1307
+
1308
+ obj2 = Series (randn (20 ))
1309
+ result = get_result (obj2 , 10 , min_periods = 5 )
1310
+ assert isna (result .iloc [3 ])
1311
+ assert notna (result .iloc [4 ])
1312
+
1313
+ if zero_min_periods_equal :
1314
+ # min_periods=0 may be equivalent to min_periods=1
1315
+ result0 = get_result (obj , 20 , min_periods = 0 )
1316
+ result1 = get_result (obj , 20 , min_periods = 1 )
1317
+ tm .assert_almost_equal (result0 , result1 )
1318
+ else :
1319
+ result = get_result (obj , 50 )
1320
+ tm .assert_almost_equal (result .iloc [- 1 ], static_comp (obj [10 :- 10 ]))
1321
+
1322
+ # window larger than series length (#7297)
1323
+ if has_min_periods :
1324
+ for minp in (0 , len (self .series ) - 1 , len (self .series )):
1325
+ result = get_result (self .series , len (self .series ) + 1 ,
1326
+ min_periods = minp )
1327
+ expected = get_result (self .series , len (self .series ),
1328
+ min_periods = minp )
1329
+ nan_mask = isna (result )
1330
+ tm .assert_series_equal (nan_mask , isna (expected ))
1331
+
1332
+ nan_mask = ~ nan_mask
1333
+ tm .assert_almost_equal (result [nan_mask ],
1334
+ expected [nan_mask ])
1335
+ else :
1336
+ result = get_result (self .series , len (self .series ) + 1 )
1337
+ expected = get_result (self .series , len (self .series ))
1338
+ nan_mask = isna (result )
1339
+ tm .assert_series_equal (nan_mask , isna (expected ))
1340
+
1341
+ nan_mask = ~ nan_mask
1342
+ tm .assert_almost_equal (result [nan_mask ], expected [nan_mask ])
1343
+
1344
+ # check center=True
1278
1345
if has_center :
1346
+ if has_min_periods :
1347
+ result = get_result (obj , 20 , min_periods = 15 , center = True )
1348
+ expected = get_result (
1349
+ pd .concat ([obj , Series ([np .NaN ] * 9 )]), 20 ,
1350
+ min_periods = 15 )[9 :].reset_index (drop = True )
1351
+ else :
1352
+ result = get_result (obj , 20 , center = True )
1353
+ expected = get_result (
1354
+ pd .concat ([obj , Series ([np .NaN ] * 9 )]),
1355
+ 20 )[9 :].reset_index (drop = True )
1356
+
1357
+ tm .assert_series_equal (result , expected )
1279
1358
1280
1359
# shifter index
1281
1360
s = ['x%d' % x for x in range (12 )]
0 commit comments