Skip to content

Commit 73f25b1

Browse files
committed
ENH: make sure return dtypes for nan funcs are consistent
1 parent bc7d48f commit 73f25b1

File tree

4 files changed

+57
-43
lines changed

4 files changed

+57
-43
lines changed

doc/source/whatsnew/v0.16.2.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ Bug Fixes
5757
- Bug where read_hdf store.select modifies the passed columns list when
5858
multi-indexed (:issue:`7212`)
5959
- Bug in ``Categorical`` repr with ``display.width`` of ``None`` in Python 3 (:issue:`10087`)
60-
60+
- Bug where some of the nan funcs do not have consistent return dtypes (:issue:`10251`)
6161
- Bug in groupby.apply aggregation for Categorical not preserving categories (:issue:`10138`)
6262
- Bug in ``mean()`` where integer dtypes can overflow (:issue:`10172`)
6363
- Bug where Panel.from_dict does not set dtype when specified (:issue:`10058`)

pandas/core/nanops.py

+35-24
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,10 @@ def nanall(values, axis=None, skipna=True):
244244
@bottleneck_switch(zero_value=0)
245245
def nansum(values, axis=None, skipna=True):
246246
values, mask, dtype, dtype_max = _get_values(values, skipna, 0)
247-
the_sum = values.sum(axis, dtype=dtype_max)
247+
dtype_sum = dtype_max
248+
if is_float_dtype(dtype):
249+
dtype_sum = dtype
250+
the_sum = values.sum(axis, dtype=dtype_sum)
248251
the_sum = _maybe_null_out(the_sum, axis, mask)
249252

250253
return _wrap_results(the_sum, dtype)
@@ -288,7 +291,7 @@ def get_median(x):
288291
return np.nan
289292
return algos.median(_values_from_object(x[mask]))
290293

291-
if values.dtype != np.float64:
294+
if not is_float_dtype(values):
292295
values = values.astype('f8')
293296
values[mask] = np.nan
294297

@@ -317,10 +320,10 @@ def get_median(x):
317320
return _wrap_results(get_median(values) if notempty else np.nan, dtype)
318321

319322

320-
def _get_counts_nanvar(mask, axis, ddof):
321-
count = _get_counts(mask, axis)
322-
323-
d = count-ddof
323+
def _get_counts_nanvar(mask, axis, ddof, dtype=float):
324+
dtype = _get_dtype(dtype)
325+
count = _get_counts(mask, axis, dtype=dtype)
326+
d = count - dtype.type(ddof)
324327

325328
# always return NaN, never inf
326329
if np.isscalar(count):
@@ -341,15 +344,19 @@ def _nanvar(values, axis=None, skipna=True, ddof=1):
341344
if is_any_int_dtype(values):
342345
values = values.astype('f8')
343346

344-
count, d = _get_counts_nanvar(mask, axis, ddof)
347+
if is_float_dtype(values):
348+
count, d = _get_counts_nanvar(mask, axis, ddof, values.dtype)
349+
else:
350+
count, d = _get_counts_nanvar(mask, axis, ddof)
345351

346352
if skipna:
347353
values = values.copy()
348354
np.putmask(values, mask, 0)
349355

350356
X = _ensure_numeric(values.sum(axis))
351357
XX = _ensure_numeric((values ** 2).sum(axis))
352-
return np.fabs((XX - X ** 2 / count) / d)
358+
result = np.fabs((XX - X * X / count) / d)
359+
return result
353360

354361
@disallow('M8')
355362
@bottleneck_switch(ddof=1)
@@ -375,9 +382,9 @@ def nansem(values, axis=None, skipna=True, ddof=1):
375382
mask = isnull(values)
376383
if not is_float_dtype(values.dtype):
377384
values = values.astype('f8')
378-
count, _ = _get_counts_nanvar(mask, axis, ddof)
385+
count, _ = _get_counts_nanvar(mask, axis, ddof, values.dtype)
379386

380-
return np.sqrt(var)/np.sqrt(count)
387+
return np.sqrt(var) / np.sqrt(count)
381388

382389

383390
@bottleneck_switch()
@@ -469,23 +476,25 @@ def nanskew(values, axis=None, skipna=True):
469476
mask = isnull(values)
470477
if not is_float_dtype(values.dtype):
471478
values = values.astype('f8')
472-
473-
count = _get_counts(mask, axis)
479+
count = _get_counts(mask, axis)
480+
else:
481+
count = _get_counts(mask, axis, dtype=values.dtype)
474482

475483
if skipna:
476484
values = values.copy()
477485
np.putmask(values, mask, 0)
478486

487+
typ = values.dtype.type
479488
A = values.sum(axis) / count
480-
B = (values ** 2).sum(axis) / count - A ** 2
481-
C = (values ** 3).sum(axis) / count - A ** 3 - 3 * A * B
489+
B = (values ** 2).sum(axis) / count - A ** typ(2)
490+
C = (values ** 3).sum(axis) / count - A ** typ(3) - typ(3) * A * B
482491

483492
# floating point error
484493
B = _zero_out_fperr(B)
485494
C = _zero_out_fperr(C)
486495

487-
result = ((np.sqrt((count ** 2 - count)) * C) /
488-
((count - 2) * np.sqrt(B) ** 3))
496+
result = ((np.sqrt(count * count - count) * C) /
497+
((count - typ(2)) * np.sqrt(B) ** typ(3)))
489498

490499
if isinstance(result, np.ndarray):
491500
result = np.where(B == 0, 0, result)
@@ -504,17 +513,19 @@ def nankurt(values, axis=None, skipna=True):
504513
mask = isnull(values)
505514
if not is_float_dtype(values.dtype):
506515
values = values.astype('f8')
507-
508-
count = _get_counts(mask, axis)
516+
count = _get_counts(mask, axis)
517+
else:
518+
count = _get_counts(mask, axis, dtype=values.dtype)
509519

510520
if skipna:
511521
values = values.copy()
512522
np.putmask(values, mask, 0)
513523

524+
typ = values.dtype.type
514525
A = values.sum(axis) / count
515-
B = (values ** 2).sum(axis) / count - A ** 2
516-
C = (values ** 3).sum(axis) / count - A ** 3 - 3 * A * B
517-
D = (values ** 4).sum(axis) / count - A ** 4 - 6 * B * A * A - 4 * C * A
526+
B = (values ** 2).sum(axis) / count - A ** typ(2)
527+
C = (values ** 3).sum(axis) / count - A ** typ(3) - typ(3) * A * B
528+
D = (values ** 4).sum(axis) / count - A ** typ(4) - typ(6) * B * A * A - typ(4) * C * A
518529

519530
B = _zero_out_fperr(B)
520531
D = _zero_out_fperr(D)
@@ -526,8 +537,8 @@ def nankurt(values, axis=None, skipna=True):
526537
if B == 0:
527538
return 0
528539

529-
result = (((count * count - 1.) * D / (B * B) - 3 * ((count - 1.) ** 2)) /
530-
((count - 2.) * (count - 3.)))
540+
result = (((count * count - typ(1)) * D / (B * B) - typ(3) * ((count - typ(1)) ** typ(2))) /
541+
((count - typ(2)) * (count - typ(3))))
531542

532543
if isinstance(result, np.ndarray):
533544
result = np.where(B == 0, 0, result)
@@ -598,7 +609,7 @@ def _zero_out_fperr(arg):
598609
if isinstance(arg, np.ndarray):
599610
return np.where(np.abs(arg) < 1e-14, 0, arg)
600611
else:
601-
return 0 if np.abs(arg) < 1e-14 else arg
612+
return arg.dtype.type(0) if np.abs(arg) < 1e-14 else arg
602613

603614

604615
@disallow('M8','m8')

pandas/tests/test_nanops.py

+13-9
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from functools import partial
55

66
import numpy as np
7-
7+
from pandas import Series
88
from pandas.core.common import isnull, is_integer_dtype
99
import pandas.core.nanops as nanops
1010
import pandas.util.testing as tm
@@ -327,7 +327,6 @@ def test_nanmean_overflow(self):
327327
# GH 10155
328328
# In the previous implementation mean can overflow for int dtypes, it
329329
# is now consistent with numpy
330-
from pandas import Series
331330

332331
# numpy < 1.9.0 is not computing this correctly
333332
from distutils.version import LooseVersion
@@ -340,14 +339,19 @@ def test_nanmean_overflow(self):
340339
self.assertEqual(result, np_result)
341340
self.assertTrue(result.dtype == np.float64)
342341

343-
# check returned dtype
344-
for dtype in [np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]:
342+
def test_returned_dtype(self):
343+
for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64, np.float128]:
345344
s = Series(range(10), dtype=dtype)
346-
result = s.mean()
347-
if is_integer_dtype(dtype):
348-
self.assertTrue(result.dtype == np.float64)
349-
else:
350-
self.assertTrue(result.dtype == dtype)
345+
group_a = ['mean', 'std', 'var', 'skew', 'kurt']
346+
group_b = ['min', 'max']
347+
for method in group_a + group_b:
348+
result = getattr(s, method)()
349+
if is_integer_dtype(dtype) and method in group_a:
350+
self.assertTrue(result.dtype == np.float64,
351+
"return dtype expected from %s is np.float64, got %s instead" % (method, result.dtype))
352+
else:
353+
self.assertTrue(result.dtype == dtype,
354+
"return dtype expected from %s is %s, got %s instead" % (method, dtype, result.dtype))
351355

352356
def test_nanmedian(self):
353357
self.check_funs(nanops.nanmedian, np.median,

pandas/tests/test_series.py

+8-9
Original file line number | Diff line number | Diff line change
@@ -528,7 +528,6 @@ def test_nansum_buglet(self):
528528
assert_almost_equal(result, 1)
529529

530530
def test_overflow(self):
531-
532531
# GH 6915
533532
# overflowing on the smaller int dtypes
534533
for dtype in ['int32','int64']:
@@ -551,25 +550,25 @@ def test_overflow(self):
551550
result = s.max()
552551
self.assertEqual(int(result),v[-1])
553552

554-
for dtype in ['float32','float64']:
555-
v = np.arange(5000000,dtype=dtype)
553+
for dtype in ['float32', 'float64']:
554+
v = np.arange(5000000, dtype=dtype)
556555
s = Series(v)
557556

558557
# no bottleneck
559558
result = s.sum(skipna=False)
560-
self.assertTrue(np.allclose(float(result),v.sum(dtype='float64')))
559+
self.assertEqual(result, v.sum(dtype=dtype))
561560
result = s.min(skipna=False)
562-
self.assertTrue(np.allclose(float(result),0.0))
561+
self.assertTrue(np.allclose(float(result), 0.0))
563562
result = s.max(skipna=False)
564-
self.assertTrue(np.allclose(float(result),v[-1]))
563+
self.assertTrue(np.allclose(float(result), v[-1]))
565564

566565
# use bottleneck if available
567566
result = s.sum()
568-
self.assertTrue(np.allclose(float(result),v.sum(dtype='float64')))
567+
self.assertEqual(result, v.sum(dtype=dtype))
569568
result = s.min()
570-
self.assertTrue(np.allclose(float(result),0.0))
569+
self.assertTrue(np.allclose(float(result), 0.0))
571570
result = s.max()
572-
self.assertTrue(np.allclose(float(result),v[-1]))
571+
self.assertTrue(np.allclose(float(result), v[-1]))
573572

574573
class SafeForSparse(object):
575574
pass

0 commit comments

Comments
 (0)