@@ -289,6 +289,193 @@ def test_deprecations(self):
289
289
mom .rolling_mean (Series (np .ones (10 )), 3 , center = True , axis = 0 )
290
290
291
291
292
+ # GH #12373 : rolling functions error on float32 data
293
+ # make sure rolling functions works for different dtypes
294
+ class TestDtype (Base ):
295
+ dtype = None
296
+ window = 2
297
+
298
+ funcs = {
299
+ 'count' : lambda v : v .count (),
300
+ 'max' : lambda v : v .max (),
301
+ 'min' : lambda v : v .min (),
302
+ 'sum' : lambda v : v .sum (),
303
+ 'mean' : lambda v : v .mean (),
304
+ 'std' : lambda v : v .std (),
305
+ 'var' : lambda v : v .var (),
306
+ 'median' : lambda v : v .median ()
307
+ }
308
+
309
+ def get_expects (self ):
310
+ expects = {
311
+ 'sr1' : {
312
+ 'count' : Series ([1 , 2 , 2 , 2 , 2 ], dtype = 'float64' ),
313
+ 'max' : Series ([np .nan , 1 , 2 , 3 , 4 ], dtype = 'float64' ),
314
+ 'min' : Series ([np .nan , 0 , 1 , 2 , 3 ], dtype = 'float64' ),
315
+ 'sum' : Series ([np .nan , 1 , 3 , 5 , 7 ], dtype = 'float64' ),
316
+ 'mean' : Series ([np .nan , .5 , 1.5 , 2.5 , 3.5 ], dtype = 'float64' ),
317
+ 'std' : Series ([np .nan ] + [np .sqrt (.5 )] * 4 , dtype = 'float64' ),
318
+ 'var' : Series ([np .nan , .5 , .5 , .5 , .5 ], dtype = 'float64' ),
319
+ 'median' : Series ([np .nan , .5 , 1.5 , 2.5 , 3.5 ], dtype = 'float64' )
320
+ },
321
+ 'sr2' : {
322
+ 'count' : Series ([1 , 2 , 2 , 2 , 2 ], dtype = 'float64' ),
323
+ 'max' : Series ([np .nan , 10 , 8 , 6 , 4 ], dtype = 'float64' ),
324
+ 'min' : Series ([np .nan , 8 , 6 , 4 , 2 ], dtype = 'float64' ),
325
+ 'sum' : Series ([np .nan , 18 , 14 , 10 , 6 ], dtype = 'float64' ),
326
+ 'mean' : Series ([np .nan , 9 , 7 , 5 , 3 ], dtype = 'float64' ),
327
+ 'std' : Series ([np .nan ] + [np .sqrt (2 )] * 4 , dtype = 'float64' ),
328
+ 'var' : Series ([np .nan , 2 , 2 , 2 , 2 ], dtype = 'float64' ),
329
+ 'median' : Series ([np .nan , 9 , 7 , 5 , 3 ], dtype = 'float64' )
330
+ },
331
+ 'df' : {
332
+ 'count' : DataFrame ({0 : Series ([1 , 2 , 2 , 2 , 2 ]),
333
+ 1 : Series ([1 , 2 , 2 , 2 , 2 ])},
334
+ dtype = 'float64' ),
335
+ 'max' : DataFrame ({0 : Series ([np .nan , 2 , 4 , 6 , 8 ]),
336
+ 1 : Series ([np .nan , 3 , 5 , 7 , 9 ])},
337
+ dtype = 'float64' ),
338
+ 'min' : DataFrame ({0 : Series ([np .nan , 0 , 2 , 4 , 6 ]),
339
+ 1 : Series ([np .nan , 1 , 3 , 5 , 7 ])},
340
+ dtype = 'float64' ),
341
+ 'sum' : DataFrame ({0 : Series ([np .nan , 2 , 6 , 10 , 14 ]),
342
+ 1 : Series ([np .nan , 4 , 8 , 12 , 16 ])},
343
+ dtype = 'float64' ),
344
+ 'mean' : DataFrame ({0 : Series ([np .nan , 1 , 3 , 5 , 7 ]),
345
+ 1 : Series ([np .nan , 2 , 4 , 6 , 8 ])},
346
+ dtype = 'float64' ),
347
+ 'std' : DataFrame ({0 : Series ([np .nan ] + [np .sqrt (2 )] * 4 ),
348
+ 1 : Series ([np .nan ] + [np .sqrt (2 )] * 4 )},
349
+ dtype = 'float64' ),
350
+ 'var' : DataFrame ({0 : Series ([np .nan , 2 , 2 , 2 , 2 ]),
351
+ 1 : Series ([np .nan , 2 , 2 , 2 , 2 ])},
352
+ dtype = 'float64' ),
353
+ 'median' : DataFrame ({0 : Series ([np .nan , 1 , 3 , 5 , 7 ]),
354
+ 1 : Series ([np .nan , 2 , 4 , 6 , 8 ])},
355
+ dtype = 'float64' ),
356
+ }
357
+ }
358
+ return expects
359
+
360
+ def _create_dtype_data (self , dtype ):
361
+ sr1 = Series (range (5 ), dtype = dtype )
362
+ sr2 = Series (range (10 , 0 , - 2 ), dtype = dtype )
363
+ df = DataFrame (np .arange (10 ).reshape ((5 , 2 )), dtype = dtype )
364
+
365
+ data = {
366
+ 'sr1' : sr1 ,
367
+ 'sr2' : sr2 ,
368
+ 'df' : df
369
+ }
370
+
371
+ return data
372
+
373
+ def _create_data (self ):
374
+ super (TestDtype , self )._create_data ()
375
+ self .data = self ._create_dtype_data (self .dtype )
376
+ self .expects = self .get_expects ()
377
+
378
+ def setUp (self ):
379
+ self ._create_data ()
380
+
381
+ def test_dtypes (self ):
382
+ for f_name , d_name in product (self .funcs .keys (), self .data .keys ()):
383
+ f = self .funcs [f_name ]
384
+ d = self .data [d_name ]
385
+ assert_equal = assert_series_equal if isinstance (
386
+ d , Series ) else assert_frame_equal
387
+ exp = self .expects [d_name ][f_name ]
388
+
389
+ roll = d .rolling (window = self .window )
390
+ result = f (roll )
391
+
392
+ assert_equal (result , exp )
393
+
394
+
395
+ class TestDtype_object (TestDtype ):
396
+ dtype = object
397
+
398
+
399
+ class TestDtype_int8 (TestDtype ):
400
+ dtype = np .int8
401
+
402
+
403
+ class TestDtype_int16 (TestDtype ):
404
+ dtype = np .int16
405
+
406
+
407
+ class TestDtype_int32 (TestDtype ):
408
+ dtype = np .int32
409
+
410
+
411
+ class TestDtype_int64 (TestDtype ):
412
+ dtype = np .int64
413
+
414
+
415
+ class TestDtype_uint8 (TestDtype ):
416
+ dtype = np .uint8
417
+
418
+
419
+ class TestDtype_uint16 (TestDtype ):
420
+ dtype = np .uint16
421
+
422
+
423
+ class TestDtype_uint32 (TestDtype ):
424
+ dtype = np .uint32
425
+
426
+
427
+ class TestDtype_uint64 (TestDtype ):
428
+ dtype = np .uint64
429
+
430
+
431
+ class TestDtype_float16 (TestDtype ):
432
+ dtype = np .float16
433
+
434
+
435
+ class TestDtype_float32 (TestDtype ):
436
+ dtype = np .float32
437
+
438
+
439
+ class TestDtype_float64 (TestDtype ):
440
+ dtype = np .float64
441
+
442
+
443
+ class TestDtype_category (TestDtype ):
444
+ dtype = 'category'
445
+ include_df = False
446
+
447
+ def _create_dtype_data (self , dtype ):
448
+ sr1 = Series (range (5 ), dtype = dtype )
449
+ sr2 = Series (range (10 , 0 , - 2 ), dtype = dtype )
450
+
451
+ data = {
452
+ 'sr1' : sr1 ,
453
+ 'sr2' : sr2
454
+ }
455
+
456
+ return data
457
+
458
+
459
+ class TestDatetimeLikeDtype (TestDtype ):
460
+ dtype = np .dtype ('M8[ns]' )
461
+
462
+ # GH #12373: rolling functions raise ValueError on float32 data
463
+ def setUp (self ):
464
+ raise nose .SkipTest ("Skip rolling on DatetimeLike dtypes [{0}]." .format (self .dtype ))
465
+
466
+ def test_dtypes (self ):
467
+ with tm .assertRaises (TypeError ):
468
+ super (TestDatetimeLikeDtype , self ).test_dtypes ()
469
+
470
+
471
+ class TestDtype_timedelta (TestDatetimeLikeDtype ):
472
+ dtype = np .dtype ('m8[ns]' )
473
+
474
+
475
+ class TestDtype_datetime64UTC (TestDatetimeLikeDtype ):
476
+ dtype = 'datetime64[ns, UTC]'
477
+
478
+
292
479
class TestMoments (Base ):
293
480
294
481
def setUp (self ):
0 commit comments