Skip to content

Commit f03b0f7

Browse files
committed
make better use of dtype tests in nanops
1 parent 9e1fbc9 commit f03b0f7

File tree

2 files changed

+51
-39
lines changed

2 files changed

+51
-39
lines changed

pandas/core/common.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -2175,6 +2175,14 @@ def is_number(obj):
21752175
return isinstance(obj, (numbers.Number, np.number))
21762176

21772177

2178+
def _get_dtype(arr_or_dtype):
2179+
if isinstance(arr_or_dtype, np.dtype):
2180+
return arr_or_dtype
2181+
if isinstance(arr_or_dtype, type):
2182+
return np.dtype(arr_or_dtype)
2183+
return arr_or_dtype.dtype
2184+
2185+
21782186
def _get_dtype_type(arr_or_dtype):
21792187
if isinstance(arr_or_dtype, np.dtype):
21802188
return arr_or_dtype.type
@@ -2206,7 +2214,7 @@ def is_datetime64_dtype(arr_or_dtype):
22062214

22072215

22082216
def is_datetime64_ns_dtype(arr_or_dtype):
2209-
tipo = _get_dtype_type(arr_or_dtype)
2217+
tipo = _get_dtype(arr_or_dtype)
22102218
return tipo == _NS_DTYPE
22112219

22122220

pandas/core/nanops.py

+42-38
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414
import pandas.hashtable as _hash
1515
from pandas import compat, lib, algos, tslib
1616
from pandas.compat import builtins
17-
from pandas.core.common import isnull, notnull, _values_from_object, is_float
17+
from pandas.core.common import (isnull, notnull, _values_from_object,
18+
_maybe_upcast_putmask,
19+
ensure_float, _ensure_float64,
20+
_ensure_int64, _ensure_object,
21+
is_float, is_integer, is_complex,
22+
is_float_dtype, _is_floating_dtype,
23+
is_complex_dtype, is_integer_dtype,
24+
is_bool_dtype, is_object_dtype,
25+
is_datetime64_dtype, is_timedelta64_dtype,
26+
_is_datetime_or_timedelta_dtype,
27+
_is_int_or_datetime_dtype, _is_any_int_dtype)
1828

1929

2030
class disallow(object):
@@ -90,8 +100,8 @@ def f(values, axis=None, skipna=True, **kwds):
90100

91101
def _bn_ok_dtype(dt, name):
92102
# Bottleneck chokes on datetime64
93-
if dt != np.object_ and not issubclass(dt.type, (np.datetime64,
94-
np.timedelta64)):
103+
if (not is_object_dtype(dt) and
104+
not _is_datetime_or_timedelta_dtype(dt)):
95105

96106
# bottleneck does not properly upcast during the sum
97107
# so can overflow
@@ -166,8 +176,7 @@ def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
166176

167177
# promote if needed
168178
else:
169-
values, changed = com._maybe_upcast_putmask(values, mask,
170-
fill_value)
179+
values, changed = _maybe_upcast_putmask(values, mask, fill_value)
171180

172181
elif copy:
173182
values = values.copy()
@@ -176,47 +185,42 @@ def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
176185

177186
# return a platform independent precision dtype
178187
dtype_max = dtype
179-
if dtype.kind == 'i' and not issubclass(dtype.type, (np.bool,
180-
np.datetime64,
181-
np.timedelta64)):
188+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
182189
dtype_max = np.int64
183-
elif dtype.kind in ['b'] or issubclass(dtype.type, np.bool):
184-
dtype_max = np.int64
185-
elif dtype.kind in ['f']:
190+
elif is_float_dtype(dtype):
186191
dtype_max = np.float64
187192

188193
return values, mask, dtype, dtype_max
189194

190195

191196
def _isfinite(values):
192-
if issubclass(values.dtype.type, (np.timedelta64, np.datetime64)):
197+
if _is_datetime_or_timedelta_dtype(values):
193198
return isnull(values)
194-
elif isinstance(values.dtype, object):
195-
return ~np.isfinite(values.astype('float64'))
196-
197-
return ~np.isfinite(values)
199+
if (is_complex_dtype(values) or is_float_dtype(values) or
200+
is_integer_dtype(values) or is_bool_dtype(values)):
201+
return ~np.isfinite(values)
202+
return ~np.isfinite(values.astype('float64'))
198203

199204

200205
def _na_ok_dtype(dtype):
201-
return not issubclass(dtype.type, (np.integer, np.datetime64,
202-
np.timedelta64))
206+
return not _is_int_or_datetime_dtype(dtype)
203207

204208

205209
def _view_if_needed(values):
206-
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
210+
if _is_datetime_or_timedelta_dtype(values):
207211
return values.view(np.int64)
208212
return values
209213

210214

211215
def _wrap_results(result, dtype):
212216
""" wrap our results if needed """
213217

214-
if issubclass(dtype.type, np.datetime64):
218+
if is_datetime64_dtype(dtype):
215219
if not isinstance(result, np.ndarray):
216220
result = lib.Timestamp(result)
217221
else:
218222
result = result.view(dtype)
219-
elif issubclass(dtype.type, np.timedelta64):
223+
elif is_timedelta64_dtype(dtype):
220224
if not isinstance(result, np.ndarray):
221225

222226
# this is a scalar timedelta result!
@@ -334,7 +338,7 @@ def _get_counts_nanvar(mask, axis, ddof):
334338
@disallow('M8')
335339
@bottleneck_switch(ddof=1)
336340
def nanvar(values, axis=None, skipna=True, ddof=1):
337-
if not isinstance(values.dtype.type, np.floating):
341+
if not _is_floating_dtype(values):
338342
values = values.astype('f8')
339343

340344
mask = isnull(values)
@@ -353,7 +357,7 @@ def nanvar(values, axis=None, skipna=True, ddof=1):
353357
def nansem(values, axis=None, skipna=True, ddof=1):
354358
var = nanvar(values, axis, skipna, ddof=ddof)
355359

356-
if not isinstance(values.dtype.type, np.floating):
360+
if not _is_floating_dtype(values):
357361
values = values.astype('f8')
358362
mask = isnull(values)
359363
count, _ = _get_counts_nanvar(mask, axis, ddof)
@@ -367,7 +371,7 @@ def nanmin(values, axis=None, skipna=True):
367371
fill_value_typ='+inf')
368372

369373
# numpy 1.6.1 workaround in Python 3.x
370-
if (values.dtype == np.object_ and compat.PY3):
374+
if is_object_dtype(values) and compat.PY3:
371375
if values.ndim > 1:
372376
apply_ax = axis if axis is not None else 0
373377
result = np.apply_along_axis(builtins.min, apply_ax, values)
@@ -380,7 +384,7 @@ def nanmin(values, axis=None, skipna=True):
380384
if ((axis is not None and values.shape[axis] == 0)
381385
or values.size == 0):
382386
try:
383-
result = com.ensure_float(values.sum(axis, dtype=dtype_max))
387+
result = ensure_float(values.sum(axis, dtype=dtype_max))
384388
result.fill(np.nan)
385389
except:
386390
result = np.nan
@@ -397,7 +401,7 @@ def nanmax(values, axis=None, skipna=True):
397401
fill_value_typ='-inf')
398402

399403
# numpy 1.6.1 workaround in Python 3.x
400-
if (values.dtype == np.object_ and compat.PY3):
404+
if is_object_dtype(values) and compat.PY3:
401405

402406
if values.ndim > 1:
403407
apply_ax = axis if axis is not None else 0
@@ -411,7 +415,7 @@ def nanmax(values, axis=None, skipna=True):
411415
if ((axis is not None and values.shape[axis] == 0)
412416
or values.size == 0):
413417
try:
414-
result = com.ensure_float(values.sum(axis, dtype=dtype_max))
418+
result = ensure_float(values.sum(axis, dtype=dtype_max))
415419
result.fill(np.nan)
416420
except:
417421
result = np.nan
@@ -446,7 +450,7 @@ def nanargmin(values, axis=None, skipna=True):
446450

447451
@disallow('M8')
448452
def nanskew(values, axis=None, skipna=True):
449-
if not isinstance(values.dtype.type, np.floating):
453+
if not _is_floating_dtype(values):
450454
values = values.astype('f8')
451455

452456
mask = isnull(values)
@@ -480,7 +484,7 @@ def nanskew(values, axis=None, skipna=True):
480484

481485
@disallow('M8')
482486
def nankurt(values, axis=None, skipna=True):
483-
if not isinstance(values.dtype.type, np.floating):
487+
if not _is_floating_dtype(values):
484488
values = values.astype('f8')
485489

486490
mask = isnull(values)
@@ -515,7 +519,7 @@ def nankurt(values, axis=None, skipna=True):
515519
@disallow('M8')
516520
def nanprod(values, axis=None, skipna=True):
517521
mask = isnull(values)
518-
if skipna and not issubclass(values.dtype.type, np.integer):
522+
if skipna and not _is_any_int_dtype(values):
519523
values = values.copy()
520524
values[mask] = 1
521525
result = values.prod(axis)
@@ -644,17 +648,17 @@ def nancov(a, b, min_periods=None):
644648

645649
def _ensure_numeric(x):
646650
if isinstance(x, np.ndarray):
647-
if x.dtype.kind in ['i', 'b']:
651+
if is_integer_dtype(x) or is_bool_dtype(x):
648652
x = x.astype(np.float64)
649-
elif x.dtype == np.object_:
653+
elif is_object_dtype(x):
650654
try:
651655
x = x.astype(np.complex128)
652656
except:
653657
x = x.astype(np.float64)
654658
else:
655659
if not np.any(x.imag):
656660
x = x.real
657-
elif not (com.is_float(x) or com.is_integer(x) or com.is_complex(x)):
661+
elif not (is_float(x) or is_integer(x) or is_complex(x)):
658662
try:
659663
x = float(x)
660664
except Exception:
@@ -678,7 +682,7 @@ def f(x, y):
678682
result = op(x, y)
679683

680684
if mask.any():
681-
if result.dtype == np.bool_:
685+
if is_bool_dtype(result):
682686
result = result.astype('O')
683687
np.putmask(result, mask, np.nan)
684688

@@ -699,16 +703,16 @@ def unique1d(values):
699703
"""
700704
if np.issubdtype(values.dtype, np.floating):
701705
table = _hash.Float64HashTable(len(values))
702-
uniques = np.array(table.unique(com._ensure_float64(values)),
706+
uniques = np.array(table.unique(_ensure_float64(values)),
703707
dtype=np.float64)
704708
elif np.issubdtype(values.dtype, np.datetime64):
705709
table = _hash.Int64HashTable(len(values))
706-
uniques = table.unique(com._ensure_int64(values))
710+
uniques = table.unique(_ensure_int64(values))
707711
uniques = uniques.view('M8[ns]')
708712
elif np.issubdtype(values.dtype, np.integer):
709713
table = _hash.Int64HashTable(len(values))
710-
uniques = table.unique(com._ensure_int64(values))
714+
uniques = table.unique(_ensure_int64(values))
711715
else:
712716
table = _hash.PyObjectHashTable(len(values))
713-
uniques = table.unique(com._ensure_object(values))
717+
uniques = table.unique(_ensure_object(values))
714718
return uniques

0 commit comments

Comments
 (0)