Skip to content

Commit c81e322

Browse files
author
tp
committed
recreate _check_ndarray as _check_series
1 parent 1eca246 commit c81e322

File tree

1 file changed

+110
-8
lines changed

1 file changed

+110
-8
lines changed

pandas/tests/test_window.py

+110-8
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,8 @@ def test_centered_axis_validation(self):
848848
.rolling(window=3, center=True, axis=2).mean())
849849

850850
def test_rolling_sum(self):
851-
self._check_moment_func(np.sum, name='sum')
851+
self._check_moment_func(np.sum, name='sum',
852+
zero_min_periods_equal=False)
852853

853854
def test_rolling_count(self):
854855
counter = lambda x: np.isfinite(x).astype(float).sum()
@@ -1149,14 +1150,21 @@ def test_rolling_quantile_param(self):
11491150
ser.rolling(3).quantile('foo')
11501151

11511152
def test_rolling_apply(self):
1152-
ser = Series([])
1153-
tm.assert_series_equal(ser,
1154-
ser.rolling(10).apply(lambda x: x.mean()))
1153+
# suppress warnings about empty slices, as we are deliberately testing
1154+
# with a 0-length Series
1155+
with warnings.catch_warnings():
1156+
warnings.filterwarnings("ignore",
1157+
message=".*(empty slice|0 for slice).*",
1158+
category=RuntimeWarning)
11551159

1156-
def f(x):
1157-
return x[np.isfinite(x)].mean()
1160+
ser = Series([])
1161+
tm.assert_series_equal(ser,
1162+
ser.rolling(10).apply(lambda x: x.mean()))
11581163

1159-
self._check_moment_func(np.mean, name='apply', func=f)
1164+
def f(x):
1165+
return x[np.isfinite(x)].mean()
1166+
1167+
self._check_moment_func(np.mean, name='apply', func=f)
11601168

11611169
# GH 8080
11621170
s = Series([None, None, None])
@@ -1230,9 +1238,103 @@ def test_rolling_kurt(self):
12301238
self._check_moment_func(lambda x: kurtosis(x, bias=False),
12311239
name='kurt')
12321240

1241+
def _check_series(self, name, static_comp, window=50,
1242+
has_min_periods=True, has_center=True,
1243+
test_stable=False, test_window=True,
1244+
zero_min_periods_equal=True, **kwargs):
1245+
1246+
def get_result(obj, window, min_periods=None, center=False):
1247+
roll = obj.rolling(window, min_periods=min_periods, center=center)
1248+
return getattr(roll, name)(**kwargs)
1249+
1250+
result = get_result(self.series, window)
1251+
tm.assert_almost_equal(result.iloc[-1],
1252+
static_comp(self.series[-50:]))
1253+
1254+
# excluding NaNs correctly
1255+
obj = Series(randn(50))
1256+
obj[:10] = np.NaN
1257+
obj[-10:] = np.NaN
1258+
1259+
if has_min_periods:
1260+
result = get_result(obj, 50, min_periods=30)
1261+
tm.assert_almost_equal(result.iloc[-1], static_comp(obj[10:-10]))
1262+
1263+
# min_periods is working correctly
1264+
result = get_result(obj, 20, min_periods=15)
1265+
assert isna(result.iloc[23])
1266+
assert not isna(result.iloc[24])
1267+
1268+
assert not isna(result.iloc[-6])
1269+
assert isna(result.iloc[-5])
1270+
1271+
obj2 = Series(randn(20))
1272+
result = get_result(obj2, 10, min_periods=5)
1273+
assert isna(result.iloc[3])
1274+
assert notna(result.iloc[4])
1275+
1276+
if zero_min_periods_equal:
1277+
# min_periods=0 may be equivalent to min_periods=1
1278+
result0 = get_result(obj, 20, min_periods=0)
1279+
result1 = get_result(obj, 20, min_periods=1)
1280+
tm.assert_almost_equal(result0, result1)
1281+
else:
1282+
result = get_result(obj, 50)
1283+
tm.assert_almost_equal(result.iloc[-1], static_comp(obj[10:-10]))
1284+
1285+
# GH 7925
1286+
if has_center:
1287+
if has_min_periods:
1288+
result = get_result(obj, 20, min_periods=15, center=True)
1289+
expected = get_result(
1290+
pd.concat([obj, Series([np.NaN] * 9)]), 20,
1291+
min_periods=15)[9:].reset_index(drop=True)
1292+
else:
1293+
result = get_result(obj, 20, center=True)
1294+
expected = get_result(
1295+
pd.concat([obj, Series([np.NaN] * 9)]),
1296+
20)[9:].reset_index(drop=True)
1297+
1298+
tm.assert_series_equal(result, expected)
1299+
1300+
if test_stable:
1301+
result = get_result(self.series + 1e9, window)
1302+
tm.assert_almost_equal(result[-1],
1303+
static_comp(self.arr[-50:] + 1e9))
1304+
1305+
# Test window larger than array, #7297
1306+
if test_window:
1307+
if has_min_periods:
1308+
for minp in (0, len(self.series) - 1, len(self.series)):
1309+
result = get_result(self.series, len(self.series) + 1,
1310+
min_periods=minp)
1311+
expected = get_result(self.series, len(self.series),
1312+
min_periods=minp)
1313+
nan_mask = isna(result)
1314+
tm.assert_series_equal(nan_mask, isna(expected))
1315+
1316+
nan_mask = ~nan_mask
1317+
tm.assert_almost_equal(result[nan_mask],
1318+
expected[nan_mask])
1319+
else:
1320+
result = get_result(self.series, len(self.series) + 1)
1321+
expected = get_result(self.series, len(self.series))
1322+
nan_mask = isna(result)
1323+
tm.assert_series_equal(nan_mask, isna(expected))
1324+
1325+
nan_mask = ~nan_mask
1326+
tm.assert_almost_equal(result[nan_mask], expected[nan_mask])
1327+
12331328
def _check_moment_func(self, static_comp, name, has_min_periods=True,
12341329
has_center=True, has_time_rule=True,
1235-
fill_value=None, **kwargs):
1330+
fill_value=None, zero_min_periods_equal=True,
1331+
**kwargs):
1332+
1333+
self._check_series(name, static_comp=static_comp,
1334+
has_min_periods=has_min_periods,
1335+
has_center=has_center,
1336+
zero_min_periods_equal=zero_min_periods_equal,
1337+
**kwargs)
12361338

12371339
def get_result(obj, window, min_periods=None, center=False):
12381340
r = obj.rolling(window=window, min_periods=min_periods,

0 commit comments

Comments
 (0)