 # dtype access #
 # --------------- #

-def _ensure_data(values, dtype=None):
+def _ensure_data(values, dtype=None, infer=True):
     """
     routine to ensure that our data is of the correct
     input dtype for lower-level routines
@@ -57,10 +57,15 @@ def _ensure_data(values, dtype=None):
     values : array-like
     dtype : pandas_dtype, optional
         coerce to this dtype
+    infer : boolean, default True
+        infer object dtypes

     Returns
     -------
-    (ndarray, pandas_dtype, algo dtype as a string)
+    (ndarray,
+     pandas_dtype,
+     algo dtype as a string,
+     inferred type as a string or None)

     """

@@ -69,28 +74,40 @@ def _ensure_data(values, dtype=None):
         if is_bool_dtype(values) or is_bool_dtype(dtype):
             # we are actually coercing to uint64
             # until our algos suppport uint8 directly (see TODO)
-            return np.asarray(values).astype('uint64'), 'bool', 'uint64'
+            return np.asarray(values).astype('uint64'), 'bool', 'uint64', None
         elif is_signed_integer_dtype(values) or is_signed_integer_dtype(dtype):
-            return _ensure_int64(values), 'int64', 'int64'
+            return _ensure_int64(values), 'int64', 'int64', None
         elif (is_unsigned_integer_dtype(values) or
               is_unsigned_integer_dtype(dtype)):
-            return _ensure_uint64(values), 'uint64', 'uint64'
+            return _ensure_uint64(values), 'uint64', 'uint64', None
         elif is_float_dtype(values) or is_float_dtype(dtype):
-            return _ensure_float64(values), 'float64', 'float64'
+            return _ensure_float64(values), 'float64', 'float64', None
         elif is_object_dtype(values) and dtype is None:
-            return _ensure_object(np.asarray(values)), 'object', 'object'
+
+            # if we can infer a numeric then do this
+            inferred = None
+            if infer:
+                inferred = lib.infer_dtype(values)
+                if inferred in ['integer']:
+                    return _ensure_int64(values), 'int64', 'int64', inferred
+                elif inferred in ['floating']:
+                    return (_ensure_float64(values),
+                            'float64', 'float64', inferred)
+
+            return (_ensure_object(np.asarray(values)),
+                    'object', 'object', inferred)
         elif is_complex_dtype(values) or is_complex_dtype(dtype):

             # ignore the fact that we are casting to float
             # which discards complex parts
             with catch_warnings(record=True):
                 values = _ensure_float64(values)
-            return values, 'float64', 'float64'
+            return values, 'float64', 'float64', None

-    except (TypeError, ValueError):
+    except (TypeError, ValueError, OverflowError):
         # if we are trying to coerce to a dtype
         # and it is incompat this will fall thru to here
-        return _ensure_object(values), 'object', 'object'
+        return _ensure_object(values), 'object', 'object', None

     # datetimelike
     if (needs_i8_conversion(values) or
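
With infer=True, the object-dtype branch above lets lib.infer_dtype decide whether the values
are really plain Python ints or floats and, if so, coerces them to a numeric algo dtype while
reporting the inferred kind back to the caller; the added OverflowError handler routes ints
that do not fit in int64 down the object path instead. A rough sketch of the intended
behavior, assuming the patch above is applied (_ensure_data is a private pandas helper):

    import numpy as np
    from pandas.core import algorithms

    obj = np.array([1, 2, 3], dtype=object)

    values, dtype, ndtype, inferred = algorithms._ensure_data(obj)
    # expected with the default infer=True: ndtype == 'int64', inferred == 'integer'

    values, dtype, ndtype, inferred = algorithms._ensure_data(obj, infer=False)
    # expected with infer=False: ndtype == 'object', inferred is None (the old behavior)
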
@@ -111,7 +128,7 @@ def _ensure_data(values, dtype=None):
             values = DatetimeIndex(values)
             dtype = values.dtype

-        return values.asi8, dtype, 'int64'
+        return values.asi8, dtype, 'int64', None

     elif is_categorical_dtype(values) or is_categorical_dtype(dtype):
         values = getattr(values, 'values', values)
@@ -122,11 +139,11 @@ def _ensure_data(values, dtype=None):
         # until our algos suppport int* directly (not all do)
         values = _ensure_int64(values)

-        return values, dtype, 'int64'
+        return values, dtype, 'int64', None

     # we have failed, return object
     values = np.asarray(values)
-    return _ensure_object(values), 'object', 'object'
+    return _ensure_object(values), 'object', 'object', None


 def _reconstruct_data(values, dtype, original):
@@ -150,7 +167,13 @@ def _reconstruct_data(values, dtype, original):
     elif is_datetime64tz_dtype(dtype) or is_period_dtype(dtype):
         values = Index(original)._shallow_copy(values, name=None)
     elif dtype is not None:
-        values = values.astype(dtype)
+
+        # don't cast to object if we are numeric
+        if is_object_dtype(dtype):
+            if not is_numeric_dtype(values):
+                values = values.astype(dtype)
+        else:
+            values = values.astype(dtype)

     return values

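
The added branch keeps values that _ensure_data inferred and coerced to a numeric dtype from
being cast back to object on reconstruction. A minimal sketch of the intent, using numpy's
issubdtype as a stand-in for the pandas is_numeric_dtype check:

    import numpy as np

    values = np.array([1, 2, 3], dtype='int64')   # numeric result coming out of the algo
    dtype = np.dtype(object)                      # dtype recorded for the original input

    # before: values = values.astype(dtype) unconditionally -> object array again
    # after:  the object cast is skipped for numeric results
    if dtype == object and np.issubdtype(values.dtype, np.number):
        pass                                      # keep the int64 result
    else:
        values = values.astype(dtype)
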
@@ -161,7 +184,7 @@ def _ensure_arraylike(values):
     """
     if not isinstance(values, (np.ndarray, ABCCategorical,
                                ABCIndexClass, ABCSeries)):
-        values = np.array(values)
+        values = np.array(values, dtype=object)
     return values


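
Passing dtype=object here keeps numpy from applying its own casting rules to list-likes, so
dtype decisions are left to lib.infer_dtype and _ensure_data. For example:

    import numpy as np

    np.array([1, 2, 3.5])                 # numpy upcasts everything to float64
    np.array([1, 2, 3.5], dtype=object)   # object array, the original scalars survive

    np.array(['a', 1])                    # numpy stringifies the int (unicode dtype)
    np.array(['a', 1], dtype=object)      # object array, 'a' and 1 keep their types
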
@@ -174,11 +197,13 @@ def _ensure_arraylike(values):
 }


-def _get_hashtable_algo(values):
+def _get_hashtable_algo(values, infer=False):
     """
     Parameters
     ----------
     values : arraylike
+    infer : boolean, default False
+        infer object dtypes

     Returns
     -------
@@ -188,12 +213,14 @@ def _get_hashtable_algo(values):
            dtype,
            ndtype)
     """
-    values, dtype, ndtype = _ensure_data(values)
+    values, dtype, ndtype, inferred = _ensure_data(values, infer=infer)

     if ndtype == 'object':

         # its cheaper to use a String Hash Table than Object
-        if lib.infer_dtype(values) in ['string']:
+        if inferred is None:
+            inferred = lib.infer_dtype(values)
+        if inferred in ['string']:
             ndtype = 'string'
         else:
             ndtype = 'object'
@@ -202,24 +229,43 @@ def _get_hashtable_algo(values):
     return (htable, table, values, dtype, ndtype)


-def _get_data_algo(values, func_map):
+def _get_data_algo(values, func_map, dtype=None, infer=False):
+    """
+    Parameters
+    ----------
+    values : array-like
+    func_map : an inferred -> function dict
+    dtype : dtype, optional
+        the requested dtype
+    infer : boolean, default False
+        infer object dtypes
+
+    Returns
+    -------
+    (function,
+     values,
+     ndtype)
+    """

     if is_categorical_dtype(values):
         values = values._values_for_rank()

-    values, dtype, ndtype = _ensure_data(values)
+    values, dtype, ndtype, inferred = _ensure_data(
+        values, dtype=dtype, infer=infer)
     if ndtype == 'object':

         # its cheaper to use a String Hash Table than Object
-        if lib.infer_dtype(values) in ['string']:
+        if inferred is None:
+            inferred = lib.infer_dtype(values)
+        if inferred in ['string']:
             try:
                 f = func_map['string']
             except KeyError:
                 pass

     f = func_map.get(ndtype, func_map['object'])

-    return f, values
+    return f, values, ndtype


 # --------------- #
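
With the new signature a caller can opt into inference and see which algo dtype was chosen,
so an object array of plain ints can be routed to the int64 implementation instead of the
object fallback. An illustrative call, assuming the patched module (both helpers are private
pandas internals):

    import numpy as np
    from pandas.core import algorithms as algos

    obj = np.array([3, 1, 2], dtype=object)
    f, vals, ndtype = algos._get_data_algo(obj, algos._rank1d_functions, infer=True)
    # expected with infer=True: ndtype == 'int64' and f is the int64 ranking function;
    # with the default infer=False it falls back to the object implementation as before
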
@@ -248,7 +294,7 @@ def match(to_match, values, na_sentinel=-1):
     """
     values = com._asarray_tuplesafe(values)
     htable, _, values, dtype, ndtype = _get_hashtable_algo(values)
-    to_match, _, _ = _ensure_data(to_match, dtype)
+    to_match, _, _, _ = _ensure_data(to_match, dtype)
     table = htable(min(len(to_match), 1000000))
     table.map_locations(values)
     result = table.lookup(to_match)
@@ -344,7 +390,7 @@ def unique(values):
         return values.unique()

     original = values
-    htable, _, values, dtype, ndtype = _get_hashtable_algo(values)
+    htable, _, values, dtype, ndtype = _get_hashtable_algo(values, infer=False)

     table = htable(len(values))
     uniques = table.unique(values)
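
unique() passes infer=False explicitly, so its round-trip behavior is unchanged: object
input comes back as object rather than being coerced to a numeric dtype. Expected behavior
under this patch (not a doctest):

    import numpy as np
    import pandas as pd

    pd.unique(np.array([1, 2, 2], dtype=object))
    # expected: array([1, 2], dtype=object) -- not an int64 array, because the
    # hashtable is selected with infer=False here
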
@@ -389,8 +435,8 @@ def isin(comps, values):
     if not isinstance(values, (ABCIndex, ABCSeries, np.ndarray)):
         values = np.array(list(values), dtype='object')

-    comps, dtype, _ = _ensure_data(comps)
-    values, _, _ = _ensure_data(values, dtype=dtype)
+    comps, dtype, _, _ = _ensure_data(comps)
+    values, _, _, _ = _ensure_data(values, dtype=dtype)

     # GH11232
     # work-around for numpy < 1.8 and comparisions on py3
@@ -499,7 +545,7 @@ def sort_mixed(values):

     if sorter is None:
         # mixed types
-        (hash_klass, _), values = _get_data_algo(values, _hashtables)
+        (hash_klass, _), values, _ = _get_data_algo(values, _hashtables)
         t = hash_klass(len(values))
         t.map_locations(values)
         sorter = _ensure_platform_int(t.lookup(ordered))
@@ -545,8 +591,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):

     values = _ensure_arraylike(values)
     original = values
-    values, dtype, _ = _ensure_data(values)
-    (hash_klass, vec_klass), values = _get_data_algo(values, _hashtables)
+    values, dtype, _, _ = _ensure_data(values)
+    (hash_klass, vec_klass), values, _ = _get_data_algo(values, _hashtables)

     table = hash_klass(size_hint or len(values))
     uniques = vec_klass()
@@ -660,7 +706,7 @@ def _value_counts_arraylike(values, dropna):
     """
     values = _ensure_arraylike(values)
     original = values
-    values, dtype, ndtype = _ensure_data(values)
+    values, dtype, ndtype, inferred = _ensure_data(values)

     if needs_i8_conversion(dtype):
         # i8
@@ -711,7 +757,7 @@ def duplicated(values, keep='first'):
     duplicated : ndarray
     """

-    values, dtype, ndtype = _ensure_data(values)
+    values, dtype, ndtype, inferred = _ensure_data(values)
     f = getattr(htable, "duplicated_{dtype}".format(dtype=ndtype))
     return f(values, keep=keep)

@@ -741,7 +787,7 @@ def mode(values):
             return Series(values.values.mode(), name=values.name)
         return values.mode()

-    values, dtype, ndtype = _ensure_data(values)
+    values, dtype, ndtype, inferred = _ensure_data(values)

     # TODO: this should support float64
     if ndtype not in ['int64', 'uint64', 'object']:
@@ -785,11 +831,11 @@ def rank(values, axis=0, method='average', na_option='keep',
     (e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1).
     """
     if values.ndim == 1:
-        f, values = _get_data_algo(values, _rank1d_functions)
+        f, values, _ = _get_data_algo(values, _rank1d_functions)
         ranks = f(values, ties_method=method, ascending=ascending,
                   na_option=na_option, pct=pct)
     elif values.ndim == 2:
-        f, values = _get_data_algo(values, _rank2d_functions)
+        f, values, _ = _get_data_algo(values, _rank2d_functions)
         ranks = f(values, axis=axis, ties_method=method,
                   ascending=ascending, na_option=na_option, pct=pct)
     else:
@@ -1049,7 +1095,7 @@ def compute(self, method):
             return dropped[slc].sort_values(ascending=ascending).head(n)

         # fast method
-        arr, _, _ = _ensure_data(dropped.values)
+        arr, _, _, _ = _ensure_data(dropped.values)
         if method == 'nlargest':
             arr = -arr
