Skip to content

Commit 3b0c542

Browse files
committed
Merge pull request #7342 from jreback/infer
PERF: better dtype inference for perf gains (GH7332)
2 parents c5c4478 + f6e9fff commit 3b0c542

File tree

8 files changed

+118
-49
lines changed

8 files changed

+118
-49
lines changed

doc/source/v0.14.1.txt

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ Enhancements
7878
Performance
7979
~~~~~~~~~~~
8080

81+
- Improvements in dtype inference for numeric operations involving yielding performance gains
82+
for dtypes: ``int64``, ``timedelta64``, ``datetime64`` (:issue:`7223`)
83+
84+
8185
Experimental
8286
~~~~~~~~~~~~
8387

pandas/core/common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1753,7 +1753,7 @@ def _possibly_cast_to_datetime(value, dtype, coerce=False):
17531753
elif is_timedelta64:
17541754
from pandas.tseries.timedeltas import \
17551755
_possibly_cast_to_timedelta
1756-
value = _possibly_cast_to_timedelta(value, coerce='compat')
1756+
value = _possibly_cast_to_timedelta(value, coerce='compat', dtype=dtype)
17571757
except:
17581758
pass
17591759

pandas/core/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def _convert_to_array(self, values, name=None, other=None):
333333
values = values.to_series()
334334
elif inferred_type in ('timedelta', 'timedelta64'):
335335
# have a timedelta, convert to to ns here
336-
values = _possibly_cast_to_timedelta(values, coerce=coerce)
336+
values = _possibly_cast_to_timedelta(values, coerce=coerce, dtype='timedelta64[ns]')
337337
elif inferred_type == 'integer':
338338
# py3 compat where dtype is 'm' but is an integer
339339
if values.dtype.kind == 'm':

pandas/src/inference.pyx

+42-21
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,47 @@ def is_complex(object obj):
1717
return util.is_complex_object(obj)
1818

1919
_TYPE_MAP = {
20-
np.int8: 'integer',
21-
np.int16: 'integer',
22-
np.int32: 'integer',
23-
np.int64: 'integer',
24-
np.uint8: 'integer',
25-
np.uint16: 'integer',
26-
np.uint32: 'integer',
27-
np.uint64: 'integer',
28-
np.float32: 'floating',
29-
np.float64: 'floating',
30-
np.complex128: 'complex',
31-
np.complex128: 'complex',
32-
np.string_: 'string',
33-
np.unicode_: 'unicode',
34-
np.bool_: 'boolean',
35-
np.datetime64 : 'datetime64',
36-
np.timedelta64 : 'timedelta64'
20+
'int8': 'integer',
21+
'int16': 'integer',
22+
'int32': 'integer',
23+
'int64': 'integer',
24+
'i' : 'integer',
25+
'uint8': 'integer',
26+
'uint16': 'integer',
27+
'uint32': 'integer',
28+
'uint64': 'integer',
29+
'u' : 'integer',
30+
'float32': 'floating',
31+
'float64': 'floating',
32+
'f' : 'floating',
33+
'complex128': 'complex',
34+
'c' : 'complex',
35+
'string': 'string',
36+
'S' : 'string',
37+
'unicode': 'unicode',
38+
'U' : 'unicode',
39+
'bool': 'boolean',
40+
'b' : 'boolean',
41+
'datetime64[ns]' : 'datetime64',
42+
'M' : 'datetime64',
43+
'timedelta64[ns]' : 'timedelta64',
44+
'm' : 'timedelta64',
3745
}
3846

47+
# types only exist on certain platform
3948
try:
40-
_TYPE_MAP[np.float128] = 'floating'
41-
_TYPE_MAP[np.complex256] = 'complex'
42-
_TYPE_MAP[np.float16] = 'floating'
49+
np.float128
50+
_TYPE_MAP['float128'] = 'floating'
51+
except AttributeError:
52+
pass
53+
try:
54+
np.complex256
55+
_TYPE_MAP['complex256'] = 'complex'
56+
except AttributeError:
57+
pass
58+
try:
59+
np.float16
60+
_TYPE_MAP['float16'] = 'floating'
4361
except AttributeError:
4462
pass
4563

@@ -60,7 +78,10 @@ def infer_dtype(object _values):
6078

6179
values = getattr(values, 'values', values)
6280

63-
val_kind = values.dtype.type
81+
val_name = values.dtype.name
82+
if val_name in _TYPE_MAP:
83+
return _TYPE_MAP[val_name]
84+
val_kind = values.dtype.kind
6485
if val_kind in _TYPE_MAP:
6586
return _TYPE_MAP[val_kind]
6687

pandas/tseries/timedeltas.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,13 @@ def convert(r=None, unit=None, m=m):
156156
# no converter
157157
raise ValueError("cannot create timedelta string converter for [{0}]".format(r))
158158

159-
def _possibly_cast_to_timedelta(value, coerce=True):
159+
def _possibly_cast_to_timedelta(value, coerce=True, dtype=None):
160160
""" try to cast to timedelta64, if already a timedeltalike, then make
161161
sure that we are [ns] (as numpy 1.6.2 is very buggy in this regards,
162162
don't force the conversion unless coerce is True
163163
164164
if coerce='compat' force a compatibilty coercerion (to timedeltas) if needeed
165+
if dtype is passed then this is the target dtype
165166
"""
166167

167168
# coercion compatability
@@ -201,10 +202,16 @@ def convert(td, dtype):
201202
return np.array([ convert(v,dtype) for v in value ], dtype='m8[ns]')
202203

203204
# deal with numpy not being able to handle certain timedelta operations
204-
if isinstance(value, (ABCSeries, np.ndarray)) and value.dtype.kind == 'm':
205-
if value.dtype != 'timedelta64[ns]':
205+
if isinstance(value, (ABCSeries, np.ndarray)):
206+
207+
# i8 conversions
208+
if value.dtype == 'int64' and np.dtype(dtype) == 'timedelta64[ns]':
206209
value = value.astype('timedelta64[ns]')
207-
return value
210+
return value
211+
elif value.dtype.kind == 'm':
212+
if value.dtype != 'timedelta64[ns]':
213+
value = value.astype('timedelta64[ns]')
214+
return value
208215

209216
# we don't have a timedelta, but we want to try to convert to one (but
210217
# don't force it)

pandas/tslib.pyx

+22-22
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ cdef inline bint _is_fixed_offset(object tz):
148148
else:
149149
return 0
150150
return 1
151-
151+
152152

153153
_zero_time = datetime_time(0, 0)
154154

@@ -340,7 +340,7 @@ class Timestamp(_Timestamp):
340340
@property
341341
def is_year_end(self):
342342
return self._get_start_end_field('is_year_end')
343-
343+
344344
def tz_localize(self, tz):
345345
"""
346346
Convert naive Timestamp to local time zone
@@ -994,7 +994,7 @@ cdef inline void _localize_tso(_TSObject obj, object tz):
994994
pandas_datetime_to_datetimestruct(obj.value + deltas[0],
995995
PANDAS_FR_ns, &obj.dts)
996996
else:
997-
pandas_datetime_to_datetimestruct(obj.value, PANDAS_FR_ns, &obj.dts)
997+
pandas_datetime_to_datetimestruct(obj.value, PANDAS_FR_ns, &obj.dts)
998998
obj.tzinfo = tz
999999
elif _treat_tz_as_pytz(tz):
10001000
inf = tz._transition_info[pos]
@@ -1044,7 +1044,7 @@ cdef inline object _get_zone(object tz):
10441044
cpdef inline object maybe_get_tz(object tz):
10451045
'''
10461046
(Maybe) Construct a timezone object from a string. If tz is a string, use it to construct a timezone object.
1047-
Otherwise, just return tz.
1047+
Otherwise, just return tz.
10481048
'''
10491049
if isinstance(tz, string_types):
10501050
if tz.startswith('dateutil/'):
@@ -1337,7 +1337,7 @@ def array_to_timedelta64(ndarray[object] values, coerce=False):
13371337
def convert_to_timedelta(object ts, object unit='ns', coerce=False):
13381338
return convert_to_timedelta64(ts, unit, coerce)
13391339

1340-
cdef convert_to_timedelta64(object ts, object unit, object coerce):
1340+
cdef inline convert_to_timedelta64(object ts, object unit, object coerce):
13411341
"""
13421342
Convert an incoming object to a timedelta64 if possible
13431343
@@ -1952,9 +1952,9 @@ cdef inline bint _treat_tz_as_dateutil(object tz):
19521952
cdef inline object _tz_cache_key(object tz):
19531953
"""
19541954
Return the key in the cache for the timezone info object or None if unknown.
1955-
1955+
19561956
The key is currently the tz string for pytz timezones, the filename for dateutil timezones.
1957-
1957+
19581958
Notes
19591959
=====
19601960
This cannot just be the hash of a timezone object. Unfortunately, the hashes of two dateutil tz objects
@@ -2136,7 +2136,7 @@ def tz_localize_to_utc(ndarray[int64_t] vals, object tz, bint infer_dst=False):
21362136
# right side
21372137
idx_shifted = _ensure_int64(
21382138
np.maximum(0, trans.searchsorted(vals + DAY_NS, side='right') - 1))
2139-
2139+
21402140
for i in range(n):
21412141
v = vals[i] - deltas[idx_shifted[i]]
21422142
pos = bisect_right_i8(tdata, v, ntrans) - 1
@@ -2516,7 +2516,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
25162516

25172517
pandas_datetime_to_datetimestruct(dtindex[i], PANDAS_FR_ns, &dts)
25182518
dom = dts.day
2519-
2519+
25202520
if dom == 1:
25212521
out[i] = 1
25222522
return out.view(bool)
@@ -2534,7 +2534,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
25342534
doy = mo_off + dom
25352535
ldom = _month_offset[isleap, dts.month]
25362536
dow = ts_dayofweek(ts)
2537-
2537+
25382538
if (ldom == doy and dow < 5) or (dow == 4 and (ldom - doy <= 2)):
25392539
out[i] = 1
25402540
return out.view(bool)
@@ -2548,9 +2548,9 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
25482548
dom = dts.day
25492549
doy = mo_off + dom
25502550
ldom = _month_offset[isleap, dts.month]
2551-
2551+
25522552
if ldom == doy:
2553-
out[i] = 1
2553+
out[i] = 1
25542554
return out.view(bool)
25552555

25562556
elif field == 'is_quarter_start':
@@ -2564,17 +2564,17 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
25642564
dow = ts_dayofweek(ts)
25652565

25662566
if ((dts.month - start_month) % 3 == 0) and ((dom == 1 and dow < 5) or (dom <= 3 and dow == 0)):
2567-
out[i] = 1
2567+
out[i] = 1
25682568
return out.view(bool)
25692569
else:
25702570
for i in range(count):
25712571
if dtindex[i] == NPY_NAT: out[i] = -1; continue
25722572

25732573
pandas_datetime_to_datetimestruct(dtindex[i], PANDAS_FR_ns, &dts)
25742574
dom = dts.day
2575-
2575+
25762576
if ((dts.month - start_month) % 3 == 0) and dom == 1:
2577-
out[i] = 1
2577+
out[i] = 1
25782578
return out.view(bool)
25792579

25802580
elif field == 'is_quarter_end':
@@ -2590,9 +2590,9 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
25902590
doy = mo_off + dom
25912591
ldom = _month_offset[isleap, dts.month]
25922592
dow = ts_dayofweek(ts)
2593-
2593+
25942594
if ((dts.month - end_month) % 3 == 0) and ((ldom == doy and dow < 5) or (dow == 4 and (ldom - doy <= 2))):
2595-
out[i] = 1
2595+
out[i] = 1
25962596
return out.view(bool)
25972597
else:
25982598
for i in range(count):
@@ -2604,9 +2604,9 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
26042604
dom = dts.day
26052605
doy = mo_off + dom
26062606
ldom = _month_offset[isleap, dts.month]
2607-
2607+
26082608
if ((dts.month - end_month) % 3 == 0) and (ldom == doy):
2609-
out[i] = 1
2609+
out[i] = 1
26102610
return out.view(bool)
26112611

26122612
elif field == 'is_year_start':
@@ -2620,7 +2620,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
26202620
dow = ts_dayofweek(ts)
26212621

26222622
if (dts.month == start_month) and ((dom == 1 and dow < 5) or (dom <= 3 and dow == 0)):
2623-
out[i] = 1
2623+
out[i] = 1
26242624
return out.view(bool)
26252625
else:
26262626
for i in range(count):
@@ -2648,7 +2648,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
26482648
ldom = _month_offset[isleap, dts.month]
26492649

26502650
if (dts.month == end_month) and ((ldom == doy and dow < 5) or (dow == 4 and (ldom - doy <= 2))):
2651-
out[i] = 1
2651+
out[i] = 1
26522652
return out.view(bool)
26532653
else:
26542654
for i in range(count):
@@ -2665,7 +2665,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field, object freqstr=N
26652665
if (dts.month == end_month) and (ldom == doy):
26662666
out[i] = 1
26672667
return out.view(bool)
2668-
2668+
26692669
raise ValueError("Field %s not supported" % field)
26702670

26712671

vb_suite/inference.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from vbench.api import Benchmark
2+
from datetime import datetime
3+
import sys
4+
5+
# from GH 7332
6+
7+
setup = """from pandas_vb_common import *
8+
import pandas as pd
9+
N = 500000
10+
df_int64 = DataFrame(dict(A = np.arange(N,dtype='int64'), B = np.arange(N,dtype='int64')))
11+
df_int32 = DataFrame(dict(A = np.arange(N,dtype='int32'), B = np.arange(N,dtype='int32')))
12+
df_uint32 = DataFrame(dict(A = np.arange(N,dtype='uint32'), B = np.arange(N,dtype='uint32')))
13+
df_float64 = DataFrame(dict(A = np.arange(N,dtype='float64'), B = np.arange(N,dtype='float64')))
14+
df_float32 = DataFrame(dict(A = np.arange(N,dtype='float32'), B = np.arange(N,dtype='float32')))
15+
df_datetime64 = DataFrame(dict(A = pd.to_datetime(np.arange(N,dtype='int64'),unit='ms'),
16+
B = pd.to_datetime(np.arange(N,dtype='int64'),unit='ms')))
17+
df_timedelta64 = DataFrame(dict(A = df_datetime64['A']-df_datetime64['B'],
18+
B = df_datetime64['B']))
19+
"""
20+
21+
dtype_infer_int64 = Benchmark('df_int64["A"] + df_int64["B"]', setup,
22+
start_date=datetime(2014, 1, 1))
23+
dtype_infer_int32 = Benchmark('df_int32["A"] + df_int32["B"]', setup,
24+
start_date=datetime(2014, 1, 1))
25+
dtype_infer_uint32 = Benchmark('df_uint32["A"] + df_uint32["B"]', setup,
26+
start_date=datetime(2014, 1, 1))
27+
dtype_infer_float64 = Benchmark('df_float64["A"] + df_float64["B"]', setup,
28+
start_date=datetime(2014, 1, 1))
29+
dtype_infer_float32 = Benchmark('df_float32["A"] + df_float32["B"]', setup,
30+
start_date=datetime(2014, 1, 1))
31+
dtype_infer_datetime64 = Benchmark('df_datetime64["A"] - df_datetime64["B"]', setup,
32+
start_date=datetime(2014, 1, 1))
33+
dtype_infer_timedelta64_1 = Benchmark('df_timedelta64["A"] + df_timedelta64["B"]', setup,
34+
start_date=datetime(2014, 1, 1))
35+
dtype_infer_timedelta64_2 = Benchmark('df_timedelta64["A"] + df_timedelta64["A"]', setup,
36+
start_date=datetime(2014, 1, 1))

vb_suite/suite.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
'index_object',
1313
'indexing',
1414
'io_bench',
15+
'inference',
1516
'hdfstore_bench',
1617
'join_merge',
1718
'miscellaneous',

0 commit comments

Comments
 (0)