Skip to content

Commit 254cedb

Browse files
BranYang authored and jreback committed
BUG: rolling functions raise ValueError on float32 data
closes pandas-dev#12373 closes pandas-dev#12376
1 parent 55f21ca commit 254cedb

File tree

3 files changed

+203
-7
lines changed

3 files changed

+203
-7
lines changed

doc/source/whatsnew/v0.18.0.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ New features
5050
Window functions are now methods
5151
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
5252

53-
Window functions have been refactored to be methods on ``Series/DataFrame`` objects, rather than top-level functions, which are now deprecated. This allows these window-type functions, to have a similar API to that of ``.groupby``. See the full documentation :ref:`here <stats.moments>` (:issue:`11603`)
53+
Window functions have been refactored to be methods on ``Series/DataFrame`` objects, rather than top-level functions, which are now deprecated. This allows these window-type functions, to have a similar API to that of ``.groupby``. See the full documentation :ref:`here <stats.moments>` (:issue:`11603`, :issue:`12373`)
54+
5455

5556
.. ipython:: python
5657

pandas/core/window.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -149,16 +149,17 @@ def _prep_values(self, values=None, kill_inf=True, how=None):
149149
if values is None:
150150
values = getattr(self._selected_obj, 'values', self._selected_obj)
151151

152-
# coerce dtypes as appropriate
152+
# GH #12373 : rolling functions error on float32 data
153+
# make sure the data is coerced to float64
153154
if com.is_float_dtype(values.dtype):
154-
pass
155+
values = com._ensure_float64(values)
155156
elif com.is_integer_dtype(values.dtype):
156-
values = values.astype(float)
157+
values = com._ensure_float64(values)
157158
elif com.is_timedelta64_dtype(values.dtype):
158-
values = values.view('i8').astype(float)
159+
values = com._ensure_float64(values.view('i8'))
159160
else:
160161
try:
161-
values = values.astype(float)
162+
values = com._ensure_float64(values)
162163
except (ValueError, TypeError):
163164
raise TypeError("cannot handle this type -> {0}"
164165
"".format(values.dtype))
@@ -457,7 +458,9 @@ def _apply(self, func, window=None, center=None, check_minp=None, how=None,
457458

458459
def func(arg, window, min_periods=None):
459460
minp = check_minp(min_periods, window)
460-
return cfunc(arg, window, minp, **kwargs)
461+
# GH #12373: rolling functions error on float32 data
462+
return cfunc(com._ensure_float64(arg),
463+
window, minp, **kwargs)
461464

462465
# calculation function
463466
if center:
@@ -494,6 +497,7 @@ def count(self):
494497
obj = self._convert_freq()
495498
window = self._get_window()
496499
window = min(window, len(obj)) if not self.center else window
500+
497501
try:
498502
converted = np.isfinite(obj).astype(float)
499503
except TypeError:
@@ -657,6 +661,10 @@ def cov(self, other=None, pairwise=None, ddof=1, **kwargs):
657661
window = self._get_window(other)
658662

659663
def _get_cov(X, Y):
664+
# GH #12373 : rolling functions error on float32 data
665+
# to avoid potential overflow, cast the data to float64
666+
X = X.astype('float64')
667+
Y = Y.astype('float64')
660668
mean = lambda x: x.rolling(window, self.min_periods,
661669
center=self.center).mean(**kwargs)
662670
count = (X + Y).rolling(window=window,

pandas/tests/test_window.py

+187
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,193 @@ def test_deprecations(self):
289289
mom.rolling_mean(Series(np.ones(10)), 3, center=True, axis=0)
290290

291291

292+
# GH #12373 : rolling functions error on float32 data
293+
# make sure rolling functions work for different dtypes
294+
class TestDtype(Base):
295+
dtype = None
296+
window = 2
297+
298+
funcs = {
299+
'count': lambda v: v.count(),
300+
'max': lambda v: v.max(),
301+
'min': lambda v: v.min(),
302+
'sum': lambda v: v.sum(),
303+
'mean': lambda v: v.mean(),
304+
'std': lambda v: v.std(),
305+
'var': lambda v: v.var(),
306+
'median': lambda v: v.median()
307+
}
308+
309+
def get_expects(self):
310+
expects = {
311+
'sr1': {
312+
'count': Series([1, 2, 2, 2, 2], dtype='float64'),
313+
'max': Series([np.nan, 1, 2, 3, 4], dtype='float64'),
314+
'min': Series([np.nan, 0, 1, 2, 3], dtype='float64'),
315+
'sum': Series([np.nan, 1, 3, 5, 7], dtype='float64'),
316+
'mean': Series([np.nan, .5, 1.5, 2.5, 3.5], dtype='float64'),
317+
'std': Series([np.nan] + [np.sqrt(.5)] * 4, dtype='float64'),
318+
'var': Series([np.nan, .5, .5, .5, .5], dtype='float64'),
319+
'median': Series([np.nan, .5, 1.5, 2.5, 3.5], dtype='float64')
320+
},
321+
'sr2': {
322+
'count': Series([1, 2, 2, 2, 2], dtype='float64'),
323+
'max': Series([np.nan, 10, 8, 6, 4], dtype='float64'),
324+
'min': Series([np.nan, 8, 6, 4, 2], dtype='float64'),
325+
'sum': Series([np.nan, 18, 14, 10, 6], dtype='float64'),
326+
'mean': Series([np.nan, 9, 7, 5, 3], dtype='float64'),
327+
'std': Series([np.nan] + [np.sqrt(2)] * 4, dtype='float64'),
328+
'var': Series([np.nan, 2, 2, 2, 2], dtype='float64'),
329+
'median': Series([np.nan, 9, 7, 5, 3], dtype='float64')
330+
},
331+
'df': {
332+
'count': DataFrame({0: Series([1, 2, 2, 2, 2]),
333+
1: Series([1, 2, 2, 2, 2])},
334+
dtype='float64'),
335+
'max': DataFrame({0: Series([np.nan, 2, 4, 6, 8]),
336+
1: Series([np.nan, 3, 5, 7, 9])},
337+
dtype='float64'),
338+
'min': DataFrame({0: Series([np.nan, 0, 2, 4, 6]),
339+
1: Series([np.nan, 1, 3, 5, 7])},
340+
dtype='float64'),
341+
'sum': DataFrame({0: Series([np.nan, 2, 6, 10, 14]),
342+
1: Series([np.nan, 4, 8, 12, 16])},
343+
dtype='float64'),
344+
'mean': DataFrame({0: Series([np.nan, 1, 3, 5, 7]),
345+
1: Series([np.nan, 2, 4, 6, 8])},
346+
dtype='float64'),
347+
'std': DataFrame({0: Series([np.nan] + [np.sqrt(2)] * 4),
348+
1: Series([np.nan] + [np.sqrt(2)] * 4)},
349+
dtype='float64'),
350+
'var': DataFrame({0: Series([np.nan, 2, 2, 2, 2]),
351+
1: Series([np.nan, 2, 2, 2, 2])},
352+
dtype='float64'),
353+
'median': DataFrame({0: Series([np.nan, 1, 3, 5, 7]),
354+
1: Series([np.nan, 2, 4, 6, 8])},
355+
dtype='float64'),
356+
}
357+
}
358+
return expects
359+
360+
def _create_dtype_data(self, dtype):
361+
sr1 = Series(range(5), dtype=dtype)
362+
sr2 = Series(range(10, 0, -2), dtype=dtype)
363+
df = DataFrame(np.arange(10).reshape((5, 2)), dtype=dtype)
364+
365+
data = {
366+
'sr1': sr1,
367+
'sr2': sr2,
368+
'df': df
369+
}
370+
371+
return data
372+
373+
def _create_data(self):
374+
super(TestDtype, self)._create_data()
375+
self.data = self._create_dtype_data(self.dtype)
376+
self.expects = self.get_expects()
377+
378+
def setUp(self):
379+
self._create_data()
380+
381+
def test_dtypes(self):
382+
for f_name, d_name in product(self.funcs.keys(), self.data.keys()):
383+
f = self.funcs[f_name]
384+
d = self.data[d_name]
385+
assert_equal = assert_series_equal if isinstance(
386+
d, Series) else assert_frame_equal
387+
exp = self.expects[d_name][f_name]
388+
389+
roll = d.rolling(window=self.window)
390+
result = f(roll)
391+
392+
assert_equal(result, exp)
393+
394+
395+
class TestDtype_object(TestDtype):
396+
dtype = object
397+
398+
399+
class TestDtype_int8(TestDtype):
400+
dtype = np.int8
401+
402+
403+
class TestDtype_int16(TestDtype):
404+
dtype = np.int16
405+
406+
407+
class TestDtype_int32(TestDtype):
408+
dtype = np.int32
409+
410+
411+
class TestDtype_int64(TestDtype):
412+
dtype = np.int64
413+
414+
415+
class TestDtype_uint8(TestDtype):
416+
dtype = np.uint8
417+
418+
419+
class TestDtype_uint16(TestDtype):
420+
dtype = np.uint16
421+
422+
423+
class TestDtype_uint32(TestDtype):
424+
dtype = np.uint32
425+
426+
427+
class TestDtype_uint64(TestDtype):
428+
dtype = np.uint64
429+
430+
431+
class TestDtype_float16(TestDtype):
432+
dtype = np.float16
433+
434+
435+
class TestDtype_float32(TestDtype):
436+
dtype = np.float32
437+
438+
439+
class TestDtype_float64(TestDtype):
440+
dtype = np.float64
441+
442+
443+
class TestDtype_category(TestDtype):
444+
dtype = 'category'
445+
include_df = False
446+
447+
def _create_dtype_data(self, dtype):
448+
sr1 = Series(range(5), dtype=dtype)
449+
sr2 = Series(range(10, 0, -2), dtype=dtype)
450+
451+
data = {
452+
'sr1': sr1,
453+
'sr2': sr2
454+
}
455+
456+
return data
457+
458+
459+
class TestDatetimeLikeDtype(TestDtype):
460+
dtype = np.dtype('M8[ns]')
461+
462+
# GH #12373: rolling functions raise ValueError on float32 data
463+
def setUp(self):
464+
raise nose.SkipTest("Skip rolling on DatetimeLike dtypes [{0}].".format(self.dtype))
465+
466+
def test_dtypes(self):
467+
with tm.assertRaises(TypeError):
468+
super(TestDatetimeLikeDtype, self).test_dtypes()
469+
470+
471+
class TestDtype_timedelta(TestDatetimeLikeDtype):
472+
dtype = np.dtype('m8[ns]')
473+
474+
475+
class TestDtype_datetime64UTC(TestDatetimeLikeDtype):
476+
dtype = 'datetime64[ns, UTC]'
477+
478+
292479
class TestMoments(Base):
293480

294481
def setUp(self):

0 commit comments

Comments
 (0)