39
39
# dtype access #
40
40
# --------------- #
41
41
42
- def _ensure_data (values , dtype = None ):
42
+ def _ensure_data (values , dtype = None , infer = True ):
43
43
"""
44
44
routine to ensure that our data is of the correct
45
45
input dtype for lower-level routines
@@ -57,10 +57,15 @@ def _ensure_data(values, dtype=None):
57
57
values : array-like
58
58
dtype : pandas_dtype, optional
59
59
coerce to this dtype
60
+ infer : boolean, default True
61
+ infer object dtypes
60
62
61
63
Returns
62
64
-------
63
- (ndarray, pandas_dtype, algo dtype as a string)
65
+ (ndarray,
66
+ pandas_dtype,
67
+ algo dtype as a string,
68
+ inferred type as a string or None)
64
69
65
70
"""
66
71
@@ -69,28 +74,40 @@ def _ensure_data(values, dtype=None):
69
74
if is_bool_dtype (values ) or is_bool_dtype (dtype ):
70
75
# we are actually coercing to uint64
71
76
# until our algos suppport uint8 directly (see TODO)
72
- return np .asarray (values ).astype ('uint64' ), 'bool' , 'uint64'
77
+ return np .asarray (values ).astype ('uint64' ), 'bool' , 'uint64' , None
73
78
elif is_signed_integer_dtype (values ) or is_signed_integer_dtype (dtype ):
74
- return _ensure_int64 (values ), 'int64' , 'int64'
79
+ return _ensure_int64 (values ), 'int64' , 'int64' , None
75
80
elif (is_unsigned_integer_dtype (values ) or
76
81
is_unsigned_integer_dtype (dtype )):
77
- return _ensure_uint64 (values ), 'uint64' , 'uint64'
82
+ return _ensure_uint64 (values ), 'uint64' , 'uint64' , None
78
83
elif is_float_dtype (values ) or is_float_dtype (dtype ):
79
- return _ensure_float64 (values ), 'float64' , 'float64'
84
+ return _ensure_float64 (values ), 'float64' , 'float64' , None
80
85
elif is_object_dtype (values ) and dtype is None :
81
- return _ensure_object (np .asarray (values )), 'object' , 'object'
86
+
87
+ # if we can infer a numeric then do this
88
+ inferred = None
89
+ if infer :
90
+ inferred = lib .infer_dtype (values )
91
+ if inferred in ['integer' ]:
92
+ return _ensure_int64 (values ), 'int64' , 'int64' , inferred
93
+ elif inferred in ['floating' ]:
94
+ return _ensure_float64 (values ),
95
+ 'float64' , 'float64' , inferred
96
+
97
+ return (_ensure_object (np .asarray (values )),
98
+ 'object' , 'object' , inferred )
82
99
elif is_complex_dtype (values ) or is_complex_dtype (dtype ):
83
100
84
101
# ignore the fact that we are casting to float
85
102
# which discards complex parts
86
103
with catch_warnings (record = True ):
87
104
values = _ensure_float64 (values )
88
- return values , 'float64' , 'float64'
105
+ return values , 'float64' , 'float64' , None
89
106
90
- except (TypeError , ValueError ):
107
+ except (TypeError , ValueError , OverflowError ):
91
108
# if we are trying to coerce to a dtype
92
109
# and it is incompat this will fall thru to here
93
- return _ensure_object (values ), 'object' , 'object'
110
+ return _ensure_object (values ), 'object' , 'object' , None
94
111
95
112
# datetimelike
96
113
if (needs_i8_conversion (values ) or
@@ -111,7 +128,7 @@ def _ensure_data(values, dtype=None):
111
128
values = DatetimeIndex (values )
112
129
dtype = values .dtype
113
130
114
- return values .asi8 , dtype , 'int64'
131
+ return values .asi8 , dtype , 'int64' , None
115
132
116
133
elif is_categorical_dtype (values ) or is_categorical_dtype (dtype ):
117
134
values = getattr (values , 'values' , values )
@@ -122,11 +139,11 @@ def _ensure_data(values, dtype=None):
122
139
# until our algos suppport int* directly (not all do)
123
140
values = _ensure_int64 (values )
124
141
125
- return values , dtype , 'int64'
142
+ return values , dtype , 'int64' , None
126
143
127
144
# we have failed, return object
128
145
values = np .asarray (values )
129
- return _ensure_object (values ), 'object' , 'object'
146
+ return _ensure_object (values ), 'object' , 'object' , None
130
147
131
148
132
149
def _reconstruct_data (values , dtype , original ):
@@ -150,7 +167,13 @@ def _reconstruct_data(values, dtype, original):
150
167
elif is_datetime64tz_dtype (dtype ) or is_period_dtype (dtype ):
151
168
values = Index (original )._shallow_copy (values , name = None )
152
169
elif dtype is not None :
153
- values = values .astype (dtype )
170
+
171
+ # don't cast to object if we are numeric
172
+ if is_object_dtype (dtype ):
173
+ if not is_numeric_dtype (values ):
174
+ values = values .astype (dtype )
175
+ else :
176
+ values = values .astype (dtype )
154
177
155
178
return values
156
179
@@ -161,7 +184,7 @@ def _ensure_arraylike(values):
161
184
"""
162
185
if not isinstance (values , (np .ndarray , ABCCategorical ,
163
186
ABCIndexClass , ABCSeries )):
164
- values = np .array (values )
187
+ values = np .array (values , dtype = object )
165
188
return values
166
189
167
190
@@ -174,11 +197,13 @@ def _ensure_arraylike(values):
174
197
}
175
198
176
199
177
- def _get_hashtable_algo (values ):
200
+ def _get_hashtable_algo (values , infer = False ):
178
201
"""
179
202
Parameters
180
203
----------
181
204
values : arraylike
205
+ infer : boolean, default False
206
+ infer object dtypes
182
207
183
208
Returns
184
209
-------
@@ -188,12 +213,12 @@ def _get_hashtable_algo(values):
188
213
dtype,
189
214
ndtype)
190
215
"""
191
- values , dtype , ndtype = _ensure_data (values )
216
+ values , dtype , ndtype , inferred = _ensure_data (values , infer = infer )
192
217
193
218
if ndtype == 'object' :
194
219
195
220
# its cheaper to use a String Hash Table than Object
196
- if lib . infer_dtype ( values ) in ['string' ]:
221
+ if inferred in ['string' ]:
197
222
ndtype = 'string'
198
223
else :
199
224
ndtype = 'object'
@@ -202,24 +227,41 @@ def _get_hashtable_algo(values):
202
227
return (htable , table , values , dtype , ndtype )
203
228
204
229
205
- def _get_data_algo (values , func_map ):
230
+ def _get_data_algo (values , func_map , dtype = None , infer = False ):
231
+ """
232
+ Parameters
233
+ ----------
234
+ values : array-like
235
+ func_map : an inferred -> function dict
236
+ dtype : dtype, optional
237
+ the requested dtype
238
+ infer : boolean, default False
239
+ infer object dtypes
240
+
241
+ Returns
242
+ -------
243
+ (function,
244
+ values,
245
+ ndtype)
246
+ """
206
247
207
248
if is_categorical_dtype (values ):
208
249
values = values ._values_for_rank ()
209
250
210
- values , dtype , ndtype = _ensure_data (values )
251
+ values , dtype , ndtype , inferred = _ensure_data (
252
+ values , dtype = dtype , infer = infer )
211
253
if ndtype == 'object' :
212
254
213
255
# its cheaper to use a String Hash Table than Object
214
- if lib . infer_dtype ( values ) in ['string' ]:
256
+ if inferred in ['string' ]:
215
257
try :
216
258
f = func_map ['string' ]
217
259
except KeyError :
218
260
pass
219
261
220
262
f = func_map .get (ndtype , func_map ['object' ])
221
263
222
- return f , values
264
+ return f , values , ndtype
223
265
224
266
225
267
# --------------- #
@@ -248,7 +290,7 @@ def match(to_match, values, na_sentinel=-1):
248
290
"""
249
291
values = com ._asarray_tuplesafe (values )
250
292
htable , _ , values , dtype , ndtype = _get_hashtable_algo (values )
251
- to_match , _ , _ = _ensure_data (to_match , dtype )
293
+ to_match , _ , _ , _ = _ensure_data (to_match , dtype )
252
294
table = htable (min (len (to_match ), 1000000 ))
253
295
table .map_locations (values )
254
296
result = table .lookup (to_match )
@@ -344,7 +386,7 @@ def unique(values):
344
386
return values .unique ()
345
387
346
388
original = values
347
- htable , _ , values , dtype , ndtype = _get_hashtable_algo (values )
389
+ htable , _ , values , dtype , ndtype = _get_hashtable_algo (values , infer = False )
348
390
349
391
table = htable (len (values ))
350
392
uniques = table .unique (values )
@@ -389,8 +431,8 @@ def isin(comps, values):
389
431
if not isinstance (values , (ABCIndex , ABCSeries , np .ndarray )):
390
432
values = np .array (list (values ), dtype = 'object' )
391
433
392
- comps , dtype , _ = _ensure_data (comps )
393
- values , _ , _ = _ensure_data (values , dtype = dtype )
434
+ comps , dtype , _ , _ = _ensure_data (comps )
435
+ values , _ , _ , _ = _ensure_data (values , dtype = dtype )
394
436
395
437
# GH11232
396
438
# work-around for numpy < 1.8 and comparisions on py3
@@ -499,7 +541,7 @@ def sort_mixed(values):
499
541
500
542
if sorter is None :
501
543
# mixed types
502
- (hash_klass , _ ), values = _get_data_algo (values , _hashtables )
544
+ (hash_klass , _ ), values , _ = _get_data_algo (values , _hashtables )
503
545
t = hash_klass (len (values ))
504
546
t .map_locations (values )
505
547
sorter = _ensure_platform_int (t .lookup (ordered ))
@@ -545,8 +587,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
545
587
546
588
values = _ensure_arraylike (values )
547
589
original = values
548
- values , dtype , _ = _ensure_data (values )
549
- (hash_klass , vec_klass ), values = _get_data_algo (values , _hashtables )
590
+ values , dtype , _ , _ = _ensure_data (values )
591
+ (hash_klass , vec_klass ), values , _ = _get_data_algo (values , _hashtables )
550
592
551
593
table = hash_klass (size_hint or len (values ))
552
594
uniques = vec_klass ()
@@ -660,7 +702,7 @@ def _value_counts_arraylike(values, dropna):
660
702
"""
661
703
values = _ensure_arraylike (values )
662
704
original = values
663
- values , dtype , ndtype = _ensure_data (values )
705
+ values , dtype , ndtype , inferred = _ensure_data (values )
664
706
665
707
if needs_i8_conversion (dtype ):
666
708
# i8
@@ -711,7 +753,7 @@ def duplicated(values, keep='first'):
711
753
duplicated : ndarray
712
754
"""
713
755
714
- values , dtype , ndtype = _ensure_data (values )
756
+ values , dtype , ndtype , inferred = _ensure_data (values )
715
757
f = getattr (htable , "duplicated_{dtype}" .format (dtype = ndtype ))
716
758
return f (values , keep = keep )
717
759
@@ -741,7 +783,7 @@ def mode(values):
741
783
return Series (values .values .mode (), name = values .name )
742
784
return values .mode ()
743
785
744
- values , dtype , ndtype = _ensure_data (values )
786
+ values , dtype , ndtype , inferred = _ensure_data (values )
745
787
746
788
# TODO: this should support float64
747
789
if ndtype not in ['int64' , 'uint64' , 'object' ]:
@@ -785,11 +827,11 @@ def rank(values, axis=0, method='average', na_option='keep',
785
827
(e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1).
786
828
"""
787
829
if values .ndim == 1 :
788
- f , values = _get_data_algo (values , _rank1d_functions )
830
+ f , values , _ = _get_data_algo (values , _rank1d_functions )
789
831
ranks = f (values , ties_method = method , ascending = ascending ,
790
832
na_option = na_option , pct = pct )
791
833
elif values .ndim == 2 :
792
- f , values = _get_data_algo (values , _rank2d_functions )
834
+ f , values , _ = _get_data_algo (values , _rank2d_functions )
793
835
ranks = f (values , axis = axis , ties_method = method ,
794
836
ascending = ascending , na_option = na_option , pct = pct )
795
837
else :
@@ -1049,7 +1091,7 @@ def compute(self, method):
1049
1091
return dropped [slc ].sort_values (ascending = ascending ).head (n )
1050
1092
1051
1093
# fast method
1052
- arr , _ , _ = _ensure_data (dropped .values )
1094
+ arr , _ , _ , _ = _ensure_data (dropped .values )
1053
1095
if method == 'nlargest' :
1054
1096
arr = - arr
1055
1097
0 commit comments