47
47
from pandas .core .dtypes .missing import isna , na_value_for_dtype
48
48
49
49
from pandas .core import common as com
50
- from pandas .core .construction import array
50
+ from pandas .core .construction import array , extract_array
51
51
from pandas .core .indexers import validate_indices
52
52
53
53
_shared_docs = {} # type: Dict[str, str]
@@ -82,9 +82,12 @@ def _ensure_data(values, dtype=None):
82
82
"""
83
83
84
84
# we check some simple dtypes first
85
+ if is_object_dtype (dtype ):
86
+ return ensure_object (np .asarray (values )), "object" , "object"
87
+ elif is_object_dtype (values ) and dtype is None :
88
+ return ensure_object (np .asarray (values )), "object" , "object"
89
+
85
90
try :
86
- if is_object_dtype (dtype ):
87
- return ensure_object (np .asarray (values )), "object" , "object"
88
91
if is_bool_dtype (values ) or is_bool_dtype (dtype ):
89
92
# we are actually coercing to uint64
90
93
# until our algos support uint8 directly (see TODO)
@@ -95,8 +98,6 @@ def _ensure_data(values, dtype=None):
95
98
return ensure_uint64 (values ), "uint64" , "uint64"
96
99
elif is_float_dtype (values ) or is_float_dtype (dtype ):
97
100
return ensure_float64 (values ), "float64" , "float64"
98
- elif is_object_dtype (values ) and dtype is None :
99
- return ensure_object (np .asarray (values )), "object" , "object"
100
101
elif is_complex_dtype (values ) or is_complex_dtype (dtype ):
101
102
102
103
# ignore the fact that we are casting to float
@@ -207,11 +208,11 @@ def _ensure_arraylike(values):
207
208
208
209
209
210
_hashtables = {
210
- "float64" : ( htable .Float64HashTable , htable . Float64Vector ) ,
211
- "uint64" : ( htable .UInt64HashTable , htable . UInt64Vector ) ,
212
- "int64" : ( htable .Int64HashTable , htable . Int64Vector ) ,
213
- "string" : ( htable .StringHashTable , htable . ObjectVector ) ,
214
- "object" : ( htable .PyObjectHashTable , htable . ObjectVector ) ,
211
+ "float64" : htable .Float64HashTable ,
212
+ "uint64" : htable .UInt64HashTable ,
213
+ "int64" : htable .Int64HashTable ,
214
+ "string" : htable .StringHashTable ,
215
+ "object" : htable .PyObjectHashTable ,
215
216
}
216
217
217
218
@@ -223,11 +224,9 @@ def _get_hashtable_algo(values):
223
224
224
225
Returns
225
226
-------
226
- tuples(hashtable class,
227
- vector class,
228
- values,
229
- dtype,
230
- ndtype)
227
+ htable : HashTable subclass
228
+ values : ndarray
229
+ dtype : str or dtype
231
230
"""
232
231
values , dtype , ndtype = _ensure_data (values )
233
232
@@ -238,23 +237,21 @@ def _get_hashtable_algo(values):
238
237
# StringHashTable and ObjectHashtable
239
238
if lib .infer_dtype (values , skipna = False ) in ["string" ]:
240
239
ndtype = "string"
241
- else :
242
- ndtype = "object"
243
240
244
- htable , table = _hashtables [ndtype ]
245
- return ( htable , table , values , dtype , ndtype )
241
+ htable = _hashtables [ndtype ]
242
+ return htable , values , dtype
246
243
247
244
248
245
def _get_values_for_rank (values ):
249
246
if is_categorical_dtype (values ):
250
247
values = values ._values_for_rank ()
251
248
252
- values , dtype , ndtype = _ensure_data (values )
253
- return values , dtype , ndtype
249
+ values , _ , ndtype = _ensure_data (values )
250
+ return values , ndtype
254
251
255
252
256
- def _get_data_algo (values , func_map ):
257
- values , dtype , ndtype = _get_values_for_rank (values )
253
+ def _get_data_algo (values ):
254
+ values , ndtype = _get_values_for_rank (values )
258
255
259
256
if ndtype == "object" :
260
257
@@ -264,7 +261,7 @@ def _get_data_algo(values, func_map):
264
261
if lib .infer_dtype (values , skipna = False ) in ["string" ]:
265
262
ndtype = "string"
266
263
267
- f = func_map .get (ndtype , func_map ["object" ])
264
+ f = _hashtables .get (ndtype , _hashtables ["object" ])
268
265
269
266
return f , values
270
267
@@ -295,7 +292,7 @@ def match(to_match, values, na_sentinel=-1):
295
292
match : ndarray of integers
296
293
"""
297
294
values = com .asarray_tuplesafe (values )
298
- htable , _ , values , dtype , ndtype = _get_hashtable_algo (values )
295
+ htable , values , dtype = _get_hashtable_algo (values )
299
296
to_match , _ , _ = _ensure_data (to_match , dtype )
300
297
table = htable (min (len (to_match ), 1000000 ))
301
298
table .map_locations (values )
@@ -398,7 +395,7 @@ def unique(values):
398
395
return values .unique ()
399
396
400
397
original = values
401
- htable , _ , values , dtype , ndtype = _get_hashtable_algo (values )
398
+ htable , values , _ = _get_hashtable_algo (values )
402
399
403
400
table = htable (len (values ))
404
401
uniques = table .unique (values )
@@ -480,7 +477,8 @@ def isin(comps, values):
480
477
481
478
482
479
def _factorize_array (values , na_sentinel = - 1 , size_hint = None , na_value = None ):
483
- """Factorize an array-like to labels and uniques.
480
+ """
481
+ Factorize an array-like to labels and uniques.
484
482
485
483
This doesn't do any coercion of types or unboxing before factorization.
486
484
@@ -498,9 +496,10 @@ def _factorize_array(values, na_sentinel=-1, size_hint=None, na_value=None):
498
496
499
497
Returns
500
498
-------
501
- labels, uniques : ndarray
499
+ labels : ndarray
500
+ uniques : ndarray
502
501
"""
503
- ( hash_klass , _ ), values = _get_data_algo (values , _hashtables )
502
+ hash_klass , values = _get_data_algo (values )
504
503
505
504
table = hash_klass (size_hint or len (values ))
506
505
uniques , labels = table .factorize (
@@ -652,17 +651,13 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
652
651
original = values
653
652
654
653
if is_extension_array_dtype (values ):
655
- values = getattr ( values , "_values" , values )
654
+ values = extract_array ( values )
656
655
labels , uniques = values .factorize (na_sentinel = na_sentinel )
657
656
dtype = original .dtype
658
657
else :
659
658
values , dtype , _ = _ensure_data (values )
660
659
661
- if (
662
- is_datetime64_any_dtype (original )
663
- or is_timedelta64_dtype (original )
664
- or is_period_dtype (original )
665
- ):
660
+ if original .dtype .kind in ["m" , "M" ]:
666
661
na_value = na_value_for_dtype (original .dtype )
667
662
else :
668
663
na_value = None
@@ -835,7 +830,7 @@ def duplicated(values, keep="first"):
835
830
duplicated : ndarray
836
831
"""
837
832
838
- values , dtype , ndtype = _ensure_data (values )
833
+ values , _ , ndtype = _ensure_data (values )
839
834
f = getattr (htable , "duplicated_{dtype}" .format (dtype = ndtype ))
840
835
return f (values , keep = keep )
841
836
@@ -872,7 +867,7 @@ def mode(values, dropna: bool = True):
872
867
mask = values .isnull ()
873
868
values = values [~ mask ]
874
869
875
- values , dtype , ndtype = _ensure_data (values )
870
+ values , _ , ndtype = _ensure_data (values )
876
871
877
872
f = getattr (htable , "mode_{dtype}" .format (dtype = ndtype ))
878
873
result = f (values , dropna = dropna )
@@ -910,7 +905,7 @@ def rank(values, axis=0, method="average", na_option="keep", ascending=True, pct
910
905
(e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1).
911
906
"""
912
907
if values .ndim == 1 :
913
- values , _ , _ = _get_values_for_rank (values )
908
+ values , _ = _get_values_for_rank (values )
914
909
ranks = algos .rank_1d (
915
910
values ,
916
911
ties_method = method ,
@@ -919,7 +914,7 @@ def rank(values, axis=0, method="average", na_option="keep", ascending=True, pct
919
914
pct = pct ,
920
915
)
921
916
elif values .ndim == 2 :
922
- values , _ , _ = _get_values_for_rank (values )
917
+ values , _ = _get_values_for_rank (values )
923
918
ranks = algos .rank_2d (
924
919
values ,
925
920
axis = axis ,
@@ -1634,9 +1629,7 @@ def take_nd(
1634
1629
if is_extension_array_dtype (arr ):
1635
1630
return arr .take (indexer , fill_value = fill_value , allow_fill = allow_fill )
1636
1631
1637
- if isinstance (arr , (ABCIndexClass , ABCSeries )):
1638
- arr = arr ._values
1639
-
1632
+ arr = extract_array (arr )
1640
1633
arr = np .asarray (arr )
1641
1634
1642
1635
if indexer is None :
0 commit comments