4
4
5
5
import pandas as pd
6
6
from pandas import DataFrame , Series , concat
7
- from pandas .tests .window .common import Base , ConsistencyBase
7
+ from pandas .tests .window .common import (
8
+ Base ,
9
+ ConsistencyBase ,
10
+ check_binary_ew ,
11
+ check_binary_ew_min_periods ,
12
+ ew_func ,
13
+ )
8
14
import pandas .util .testing as tm
9
15
10
16
@@ -216,6 +222,9 @@ def _check_ew(self, name=None, preserve_nan=False):
216
222
if preserve_nan :
217
223
assert result [self ._nan_locs ].isna ().all ()
218
224
225
+ @pytest .mark .parametrize ("min_periods" , [0 , 1 ])
226
+ @pytest .mark .parametrize ("name" , ["mean" , "var" , "vol" ])
227
+ def test_ew_min_periods (self , min_periods , name ):
219
228
# excluding NaNs correctly
220
229
arr = randn (50 )
221
230
arr [:10 ] = np .NaN
@@ -228,31 +237,30 @@ def _check_ew(self, name=None, preserve_nan=False):
228
237
assert result [:11 ].isna ().all ()
229
238
assert not result [11 :].isna ().any ()
230
239
231
- for min_periods in (0 , 1 ):
232
- result = getattr (s .ewm (com = 50 , min_periods = min_periods ), name )()
233
- if name == "mean" :
234
- assert result [:10 ].isna ().all ()
235
- assert not result [10 :].isna ().any ()
236
- else :
237
- # ewm.std, ewm.vol, ewm.var (with bias=False) require at least
238
- # two values
239
- assert result [:11 ].isna ().all ()
240
- assert not result [11 :].isna ().any ()
241
-
242
- # check series of length 0
243
- result = getattr (
244
- Series (dtype = object ).ewm (com = 50 , min_periods = min_periods ), name
245
- )()
246
- tm .assert_series_equal (result , Series (dtype = "float64" ))
247
-
248
- # check series of length 1
249
- result = getattr (Series ([1.0 ]).ewm (50 , min_periods = min_periods ), name )()
250
- if name == "mean" :
251
- tm .assert_series_equal (result , Series ([1.0 ]))
252
- else :
253
- # ewm.std, ewm.vol, ewm.var with bias=False require at least
254
- # two values
255
- tm .assert_series_equal (result , Series ([np .NaN ]))
240
+ result = getattr (s .ewm (com = 50 , min_periods = min_periods ), name )()
241
+ if name == "mean" :
242
+ assert result [:10 ].isna ().all ()
243
+ assert not result [10 :].isna ().any ()
244
+ else :
245
+ # ewm.std, ewm.vol, ewm.var (with bias=False) require at least
246
+ # two values
247
+ assert result [:11 ].isna ().all ()
248
+ assert not result [11 :].isna ().any ()
249
+
250
+ # check series of length 0
251
+ result = getattr (
252
+ Series (dtype = object ).ewm (com = 50 , min_periods = min_periods ), name
253
+ )()
254
+ tm .assert_series_equal (result , Series (dtype = "float64" ))
255
+
256
+ # check series of length 1
257
+ result = getattr (Series ([1.0 ]).ewm (50 , min_periods = min_periods ), name )()
258
+ if name == "mean" :
259
+ tm .assert_series_equal (result , Series ([1.0 ]))
260
+ else :
261
+ # ewm.std, ewm.vol, ewm.var with bias=False require at least
262
+ # two values
263
+ tm .assert_series_equal (result , Series ([np .NaN ]))
256
264
257
265
# pass in ints
258
266
result2 = getattr (Series (np .arange (50 )).ewm (span = 10 ), name )()
@@ -263,53 +271,27 @@ class TestEwmMomentsConsistency(ConsistencyBase):
263
271
def setup_method (self , method ):
264
272
self ._create_data ()
265
273
266
- def test_ewmcov (self ):
267
- self ._check_binary_ew ("cov" )
268
-
269
274
def test_ewmcov_pairwise (self ):
270
275
self ._check_pairwise_moment ("ewm" , "cov" , span = 10 , min_periods = 5 )
271
276
272
- def test_ewmcorr (self ):
273
- self ._check_binary_ew ("corr" )
277
+ @pytest .mark .parametrize ("name" , ["cov" , "corr" ])
278
+ def test_ewm_corr_cov (self , name , min_periods , binary_ew_data ):
279
+ A , B = binary_ew_data
280
+
281
+ check_binary_ew (name = "corr" , A = A , B = B )
282
+ check_binary_ew_min_periods ("corr" , min_periods , A , B )
274
283
275
284
def test_ewmcorr_pairwise (self ):
276
285
self ._check_pairwise_moment ("ewm" , "corr" , span = 10 , min_periods = 5 )
277
286
278
- def _check_binary_ew (self , name ):
279
- def func (A , B , com , ** kwargs ):
280
- return getattr (A .ewm (com , ** kwargs ), name )(B )
281
-
282
- A = Series (randn (50 ), index = np .arange (50 ))
283
- B = A [2 :] + randn (48 )
284
-
285
- A [:10 ] = np .NaN
286
- B [- 10 :] = np .NaN
287
-
288
- result = func (A , B , 20 , min_periods = 5 )
289
- assert np .isnan (result .values [:14 ]).all ()
290
- assert not np .isnan (result .values [14 :]).any ()
291
-
292
- # GH 7898
293
- for min_periods in (0 , 1 , 2 ):
294
- result = func (A , B , 20 , min_periods = min_periods )
295
- # binary functions (ewmcov, ewmcorr) with bias=False require at
296
- # least two values
297
- assert np .isnan (result .values [:11 ]).all ()
298
- assert not np .isnan (result .values [11 :]).any ()
299
-
300
- # check series of length 0
301
- empty = Series ([], dtype = np .float64 )
302
- result = func (empty , empty , 50 , min_periods = min_periods )
303
- tm .assert_series_equal (result , empty )
304
-
305
- # check series of length 1
306
- result = func (Series ([1.0 ]), Series ([1.0 ]), 50 , min_periods = min_periods )
307
- tm .assert_series_equal (result , Series ([np .NaN ]))
287
+ @pytest .mark .parametrize ("name" , ["cov" , "corr" ])
288
+ def test_different_input_array_raise_exception (self , name , binary_ew_data ):
308
289
290
+ A , _ = binary_ew_data
309
291
msg = "Input arrays must be of the same type!"
310
292
# exception raised is Exception
311
293
with pytest .raises (Exception , match = msg ):
312
- func (A , randn (50 ), 20 , min_periods = 5 )
294
+ ew_func (A , randn (50 ), 20 , name = name , min_periods = 5 )
313
295
314
296
@pytest .mark .slow
315
297
@pytest .mark .parametrize ("min_periods" , [0 , 1 , 2 , 3 , 4 ])
0 commit comments