11
11
from pandas .core .dtypes .common import (
12
12
ensure_int64 , ensure_platform_int , is_categorical_dtype , is_list_like )
13
13
from pandas .core .dtypes .missing import isna
14
+ from pandas .core .dtypes .generic import ABCExtensionArray
14
15
15
16
import pandas .core .algorithms as algorithms
16
17
@@ -404,7 +405,8 @@ def _reorder_by_uniques(uniques, labels):
404
405
return uniques , labels
405
406
406
407
407
- def safe_sort (values , labels = None , na_sentinel = - 1 , assume_unique = False ):
408
+ def safe_sort (values , labels = None , na_sentinel = - 1 , assume_unique = False ,
409
+ check_outofbounds = True ):
408
410
"""
409
411
Sort ``values`` and reorder corresponding ``labels``.
410
412
``values`` should be unique if ``labels`` is not None.
@@ -425,6 +427,10 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
425
427
assume_unique : bool, default False
426
428
When True, ``values`` are assumed to be unique, which can speed up
427
429
the calculation. Ignored when ``labels`` is None.
430
+ check_outofbounds : bool, default True
431
+ Check if labels are out of bound for the values and put out of bound
432
+ labels equal to na_sentinel. If ``check_outofbounds=False``, it is
433
+ assumed there are no out of bound labels.
428
434
429
435
Returns
430
436
-------
@@ -446,8 +452,8 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
446
452
raise TypeError ("Only list-like objects are allowed to be passed to"
447
453
"safe_sort as values" )
448
454
449
- if not isinstance (values , np .ndarray ):
450
-
455
+ if ( not isinstance (values , np .ndarray )
456
+ and not isinstance ( values , ABCExtensionArray )):
451
457
# don't convert to string types
452
458
dtype , _ = infer_dtype_from_array (values )
453
459
values = np .asarray (values , dtype = dtype )
@@ -461,7 +467,8 @@ def sort_mixed(values):
461
467
return np .concatenate ([nums , np .asarray (strs , dtype = object )])
462
468
463
469
sorter = None
464
- if PY3 and lib .infer_dtype (values , skipna = False ) == 'mixed-integer' :
470
+ if (PY3 and not isinstance (values , ABCExtensionArray )
471
+ and lib .infer_dtype (values , skipna = False ) == 'mixed-integer' ):
465
472
# unorderable in py3 if mixed str/int
466
473
ordered = sort_mixed (values )
467
474
else :
@@ -494,15 +501,26 @@ def sort_mixed(values):
494
501
t .map_locations (values )
495
502
sorter = ensure_platform_int (t .lookup (ordered ))
496
503
497
- reverse_indexer = np .empty (len (sorter ), dtype = np .int_ )
498
- reverse_indexer .put (sorter , np .arange (len (sorter )))
499
-
500
- mask = (labels < - len (values )) | (labels >= len (values )) | \
501
- (labels == na_sentinel )
502
-
503
- # (Out of bound indices will be masked with `na_sentinel` next, so we may
504
- # deal with them here without performance loss using `mode='wrap'`.)
505
- new_labels = reverse_indexer .take (labels , mode = 'wrap' )
506
- np .putmask (new_labels , mask , na_sentinel )
504
+ if na_sentinel == - 1 :
505
+ # take_1d is faster, but only works for na_sentinels of -1
506
+ order2 = sorter .argsort ()
507
+ new_labels = algorithms .take_1d (order2 , labels , fill_value = - 1 )
508
+ if check_outofbounds :
509
+ mask = (labels < - len (values )) | (labels >= len (values ))
510
+ else :
511
+ mask = None
512
+ else :
513
+ reverse_indexer = np .empty (len (sorter ), dtype = np .int_ )
514
+ reverse_indexer .put (sorter , np .arange (len (sorter )))
515
+ # Out of bound indices will be masked with `na_sentinel` next, so we
516
+ # may deal with them here without performance loss using `mode='wrap'`
517
+ new_labels = reverse_indexer .take (labels , mode = 'wrap' )
518
+
519
+ mask = labels == na_sentinel
520
+ if check_outofbounds :
521
+ mask = mask | (labels < - len (values )) | (labels >= len (values ))
522
+
523
+ if mask is not None :
524
+ np .putmask (new_labels , mask , na_sentinel )
507
525
508
526
return ordered , ensure_platform_int (new_labels )
0 commit comments