Skip to content

Commit 3896e5e

Browse files
committed
BUG: mean overflows for integer dtypes (fixes #10155)
1 parent 0aceb38 commit 3896e5e

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ Bug Fixes
6363
- Bug in ``Categorical`` repr with ``display.width`` of ``None`` in Python 3 (:issue:`10087`)
6464

6565

66+
- Bug in ``mean()`` where integer dtypes can overflow (:issue:`10172`)
6667
- Bug where Panel.from_dict does not set dtype when specified (:issue:`10058`)
6768
- Bug in ``Timestamp``'s ``microsecond``, ``quarter``, ``dayofyear``, ``week`` and ``daysinmonth`` properties returning ``np.int`` type, not built-in ``int``. (:issue:`10050`)
6869
- Bug where ``NaT`` raises ``AttributeError`` when accessing the ``daysinmonth`` and ``dayofweek`` properties. (:issue:`10096`)

pandas/core/nanops.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
is_complex_dtype, is_integer_dtype,
2121
is_bool_dtype, is_object_dtype,
2222
is_datetime64_dtype, is_timedelta64_dtype,
23-
is_datetime_or_timedelta_dtype,
23+
is_datetime_or_timedelta_dtype, _get_dtype,
2424
is_int_or_datetime_dtype, is_any_int_dtype)
2525

2626

@@ -254,8 +254,16 @@ def nansum(values, axis=None, skipna=True):
254254
@bottleneck_switch()
255255
def nanmean(values, axis=None, skipna=True):
256256
values, mask, dtype, dtype_max = _get_values(values, skipna, 0)
257-
the_sum = _ensure_numeric(values.sum(axis, dtype=dtype_max))
258-
count = _get_counts(mask, axis)
257+
258+
dtype_sum = dtype_max
259+
dtype_count = np.float64
260+
if is_integer_dtype(dtype):
261+
dtype_sum = np.float64
262+
elif is_float_dtype(dtype):
263+
dtype_sum = dtype
264+
dtype_count = dtype
265+
count = _get_counts(mask, axis, dtype=dtype_count)
266+
the_sum = _ensure_numeric(values.sum(axis, dtype=dtype_sum))
259267

260268
if axis is not None and getattr(the_sum, 'ndim', False):
261269
the_mean = the_sum / count
@@ -557,15 +565,16 @@ def _maybe_arg_null_out(result, axis, mask, skipna):
557565
return result
558566

559567

560-
def _get_counts(mask, axis):
568+
def _get_counts(mask, axis, dtype=float):
569+
dtype = _get_dtype(dtype)
561570
if axis is None:
562-
return float(mask.size - mask.sum())
571+
return dtype.type(mask.size - mask.sum())
563572

564573
count = mask.shape[axis] - mask.sum(axis)
565574
try:
566-
return count.astype(float)
575+
return count.astype(dtype)
567576
except AttributeError:
568-
return np.array(count, dtype=float)
577+
return np.array(count, dtype=dtype)
569578

570579

571580
def _maybe_null_out(result, axis, mask):

pandas/tests/test_nanops.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77

8-
from pandas.core.common import isnull
8+
from pandas.core.common import isnull, is_integer_dtype
99
import pandas.core.nanops as nanops
1010
import pandas.util.testing as tm
1111

@@ -323,6 +323,32 @@ def test_nanmean(self):
323323
allow_complex=False, allow_obj=False,
324324
allow_str=False, allow_date=False, allow_tdelta=True)
325325

326+
def test_nanmean_overflow(self):
327+
# GH 10155
328+
# In the previous implementation mean can overflow for int dtypes, it
329+
# is now consistent with numpy
330+
from pandas import Series
331+
332+
# numpy < 1.9.0 is not computing this correctly
333+
from distutils.version import LooseVersion
334+
if LooseVersion(np.__version__) >= '1.9.0':
335+
for a in [2 ** 55, -2 ** 55, 20150515061816532]:
336+
s = Series(a, index=range(500), dtype=np.int64)
337+
result = s.mean()
338+
np_result = s.values.mean()
339+
self.assertEqual(result, a)
340+
self.assertEqual(result, np_result)
341+
self.assertTrue(result.dtype == np.float64)
342+
343+
# check returned dtype
344+
for dtype in [np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]:
345+
s = Series(range(10), dtype=dtype)
346+
result = s.mean()
347+
if is_integer_dtype(dtype):
348+
self.assertTrue(result.dtype == np.float64)
349+
else:
350+
self.assertTrue(result.dtype == dtype)
351+
326352
def test_nanmedian(self):
327353
self.check_funs(nanops.nanmedian, np.median,
328354
allow_complex=False, allow_str=False, allow_date=False,

0 commit comments

Comments
 (0)