Skip to content

Commit 1443818

Browse files
author
tp
committed
recreate _check_ndarray inline in _check_moment_func
1 parent 045e47a commit 1443818

File tree

1 file changed

+91
-12
lines changed

1 file changed

+91
-12
lines changed

pandas/tests/test_window.py

+91-12
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,8 @@ def test_centered_axis_validation(self):
848848
.rolling(window=3, center=True, axis=2).mean())
849849

850850
def test_rolling_sum(self):
851-
self._check_moment_func(np.sum, name='sum')
851+
self._check_moment_func(np.nansum, name='sum',
852+
zero_min_periods_equal=False)
852853

853854
def test_rolling_count(self):
854855
counter = lambda x: np.isfinite(x).astype(float).sum()
@@ -1149,14 +1150,21 @@ def test_rolling_quantile_param(self):
11491150
ser.rolling(3).quantile('foo')
11501151

11511152
def test_rolling_apply(self):
1152-
ser = Series([])
1153-
tm.assert_series_equal(ser,
1154-
ser.rolling(10).apply(lambda x: x.mean()))
1153+
# suppress warnings about empty slices, as we are deliberately testing
1154+
# with a 0-length Series
1155+
with warnings.catch_warnings():
1156+
warnings.filterwarnings("ignore",
1157+
message=".*(empty slice|0 for slice).*",
1158+
category=RuntimeWarning)
1159+
1160+
ser = Series([])
1161+
tm.assert_series_equal(ser,
1162+
ser.rolling(10).apply(lambda x: x.mean()))
11551163

1156-
def f(x):
1157-
return x[np.isfinite(x)].mean()
1164+
def f(x):
1165+
return x[np.isfinite(x)].mean()
11581166

1159-
self._check_moment_func(np.mean, name='apply', func=f)
1167+
self._check_moment_func(np.mean, name='apply', func=f)
11601168

11611169
# GH 8080
11621170
s = Series([None, None, None])
@@ -1232,18 +1240,25 @@ def test_rolling_kurt(self):
12321240

12331241
def _check_moment_func(self, static_comp, name, has_min_periods=True,
12341242
has_center=True, has_time_rule=True,
1235-
fill_value=None, **kwargs):
1243+
fill_value=None, zero_min_periods_equal=True,
1244+
**kwargs):
12361245

12371246
def get_result(obj, window, min_periods=None, center=False):
12381247
r = obj.rolling(window=window, min_periods=min_periods,
12391248
center=center)
12401249
return getattr(r, name)(**kwargs)
12411250

12421251
series_result = get_result(self.series, window=50)
1243-
frame_result = get_result(self.frame, window=50)
1244-
12451252
assert isinstance(series_result, Series)
1246-
assert type(frame_result) == DataFrame
1253+
tm.assert_almost_equal(series_result.iloc[-1],
1254+
static_comp(self.series[-50:]))
1255+
1256+
frame_result = get_result(self.frame, window=50)
1257+
assert isinstance(frame_result, DataFrame)
1258+
tm.assert_series_equal(frame_result.iloc[-1, :],
1259+
self.frame.iloc[-50:, :].apply(static_comp,
1260+
axis=0),
1261+
check_names=False)
12471262

12481263
# check time_rule works
12491264
if has_time_rule:
@@ -1274,8 +1289,72 @@ def get_result(obj, window, min_periods=None, center=False):
12741289
trunc_frame.apply(static_comp),
12751290
check_names=False)
12761291

1277-
# GH 7925
1292+
# excluding NaNs correctly
1293+
obj = Series(randn(50))
1294+
obj[:10] = np.NaN
1295+
obj[-10:] = np.NaN
1296+
if has_min_periods:
1297+
result = get_result(obj, 50, min_periods=30)
1298+
tm.assert_almost_equal(result.iloc[-1], static_comp(obj[10:-10]))
1299+
1300+
# min_periods is working correctly
1301+
result = get_result(obj, 20, min_periods=15)
1302+
assert isna(result.iloc[23])
1303+
assert not isna(result.iloc[24])
1304+
1305+
assert not isna(result.iloc[-6])
1306+
assert isna(result.iloc[-5])
1307+
1308+
obj2 = Series(randn(20))
1309+
result = get_result(obj2, 10, min_periods=5)
1310+
assert isna(result.iloc[3])
1311+
assert notna(result.iloc[4])
1312+
1313+
if zero_min_periods_equal:
1314+
# min_periods=0 may be equivalent to min_periods=1
1315+
result0 = get_result(obj, 20, min_periods=0)
1316+
result1 = get_result(obj, 20, min_periods=1)
1317+
tm.assert_almost_equal(result0, result1)
1318+
else:
1319+
result = get_result(obj, 50)
1320+
tm.assert_almost_equal(result.iloc[-1], static_comp(obj[10:-10]))
1321+
1322+
# window larger than series length (#7297)
1323+
if has_min_periods:
1324+
for minp in (0, len(self.series) - 1, len(self.series)):
1325+
result = get_result(self.series, len(self.series) + 1,
1326+
min_periods=minp)
1327+
expected = get_result(self.series, len(self.series),
1328+
min_periods=minp)
1329+
nan_mask = isna(result)
1330+
tm.assert_series_equal(nan_mask, isna(expected))
1331+
1332+
nan_mask = ~nan_mask
1333+
tm.assert_almost_equal(result[nan_mask],
1334+
expected[nan_mask])
1335+
else:
1336+
result = get_result(self.series, len(self.series) + 1)
1337+
expected = get_result(self.series, len(self.series))
1338+
nan_mask = isna(result)
1339+
tm.assert_series_equal(nan_mask, isna(expected))
1340+
1341+
nan_mask = ~nan_mask
1342+
tm.assert_almost_equal(result[nan_mask], expected[nan_mask])
1343+
1344+
# check center=True
12781345
if has_center:
1346+
if has_min_periods:
1347+
result = get_result(obj, 20, min_periods=15, center=True)
1348+
expected = get_result(
1349+
pd.concat([obj, Series([np.NaN] * 9)]), 20,
1350+
min_periods=15)[9:].reset_index(drop=True)
1351+
else:
1352+
result = get_result(obj, 20, center=True)
1353+
expected = get_result(
1354+
pd.concat([obj, Series([np.NaN] * 9)]),
1355+
20)[9:].reset_index(drop=True)
1356+
1357+
tm.assert_series_equal(result, expected)
12791358

12801359
# shifter index
12811360
s = ['x%d' % x for x in range(12)]

0 commit comments

Comments
 (0)