Skip to content

Commit a2686c6

Browse files
jorisvandenbossche authored and jreback committed
CLN: handle EAs and fast path (no bounds checking) in safe_sort (#25696)
1 parent 6df1219 commit a2686c6

File tree

5 files changed

+93
-48
lines changed

5 files changed

+93
-48
lines changed

doc/source/whatsnew/v0.25.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -428,6 +428,7 @@ Other
428428
^^^^^
429429

430430
- Removed unused C functions from vendored UltraJSON implementation (:issue:`26198`)
431+
- Bug in :func:`factorize` when passing an ``ExtensionArray`` with a custom ``na_sentinel`` (:issue:`25696`).
431432

432433

433434
.. _whatsnew_0.250.contributors:

pandas/core/algorithms.py

+2 -16
Original file line number | Diff line number | Diff line change
@@ -617,22 +617,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
617617

618618
if sort and len(uniques) > 0:
619619
from pandas.core.sorting import safe_sort
620-
if na_sentinel == -1:
621-
# GH-25409 take_1d only works for na_sentinels of -1
622-
try:
623-
order = uniques.argsort()
624-
order2 = order.argsort()
625-
labels = take_1d(order2, labels, fill_value=na_sentinel)
626-
uniques = uniques.take(order)
627-
except TypeError:
628-
# Mixed types, where uniques.argsort fails.
629-
uniques, labels = safe_sort(uniques, labels,
630-
na_sentinel=na_sentinel,
631-
assume_unique=True)
632-
else:
633-
uniques, labels = safe_sort(uniques, labels,
634-
na_sentinel=na_sentinel,
635-
assume_unique=True)
620+
uniques, labels = safe_sort(uniques, labels, na_sentinel=na_sentinel,
621+
assume_unique=True, verify=False)
636622

637623
uniques = _reconstruct_data(uniques, dtype, original)
638624

pandas/core/sorting.py

+35 -15
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88

99
from pandas.core.dtypes.cast import infer_dtype_from_array
1010
from pandas.core.dtypes.common import (
11-
ensure_int64, ensure_platform_int, is_categorical_dtype, is_list_like)
11+
ensure_int64, ensure_platform_int, is_categorical_dtype,
12+
is_extension_array_dtype, is_list_like)
1213
from pandas.core.dtypes.missing import isna
1314

1415
import pandas.core.algorithms as algorithms
@@ -403,7 +404,8 @@ def _reorder_by_uniques(uniques, labels):
403404
return uniques, labels
404405

405406

406-
def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
407+
def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False,
408+
verify=True):
407409
"""
408410
Sort ``values`` and reorder corresponding ``labels``.
409411
``values`` should be unique if ``labels`` is not None.
@@ -424,6 +426,12 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
424426
assume_unique : bool, default False
425427
When True, ``values`` are assumed to be unique, which can speed up
426428
the calculation. Ignored when ``labels`` is None.
429+
verify : bool, default True
430+
Check if labels are out of bound for the values and put out of bound
431+
labels equal to na_sentinel. If ``verify=False``, it is assumed there
432+
are no out of bound labels. Ignored when ``labels`` is None.
433+
434+
.. versionadded:: 0.25.0
427435
428436
Returns
429437
-------
@@ -445,8 +453,8 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
445453
raise TypeError("Only list-like objects are allowed to be passed to"
446454
"safe_sort as values")
447455

448-
if not isinstance(values, np.ndarray):
449-
456+
if (not isinstance(values, np.ndarray)
457+
and not is_extension_array_dtype(values)):
450458
# don't convert to string types
451459
dtype, _ = infer_dtype_from_array(values)
452460
values = np.asarray(values, dtype=dtype)
@@ -460,7 +468,8 @@ def sort_mixed(values):
460468
return np.concatenate([nums, np.asarray(strs, dtype=object)])
461469

462470
sorter = None
463-
if lib.infer_dtype(values, skipna=False) == 'mixed-integer':
471+
if (not is_extension_array_dtype(values)
472+
and lib.infer_dtype(values, skipna=False) == 'mixed-integer'):
464473
# unorderable in py3 if mixed str/int
465474
ordered = sort_mixed(values)
466475
else:
@@ -493,15 +502,26 @@ def sort_mixed(values):
493502
t.map_locations(values)
494503
sorter = ensure_platform_int(t.lookup(ordered))
495504

496-
reverse_indexer = np.empty(len(sorter), dtype=np.int_)
497-
reverse_indexer.put(sorter, np.arange(len(sorter)))
498-
499-
mask = (labels < -len(values)) | (labels >= len(values)) | \
500-
(labels == na_sentinel)
501-
502-
# (Out of bound indices will be masked with `na_sentinel` next, so we may
503-
# deal with them here without performance loss using `mode='wrap'`.)
504-
new_labels = reverse_indexer.take(labels, mode='wrap')
505-
np.putmask(new_labels, mask, na_sentinel)
505+
if na_sentinel == -1:
506+
# take_1d is faster, but only works for na_sentinels of -1
507+
order2 = sorter.argsort()
508+
new_labels = algorithms.take_1d(order2, labels, fill_value=-1)
509+
if verify:
510+
mask = (labels < -len(values)) | (labels >= len(values))
511+
else:
512+
mask = None
513+
else:
514+
reverse_indexer = np.empty(len(sorter), dtype=np.int_)
515+
reverse_indexer.put(sorter, np.arange(len(sorter)))
516+
# Out of bound indices will be masked with `na_sentinel` next, so we
517+
# may deal with them here without performance loss using `mode='wrap'`
518+
new_labels = reverse_indexer.take(labels, mode='wrap')
519+
520+
mask = labels == na_sentinel
521+
if verify:
522+
mask = mask | (labels < -len(values)) | (labels >= len(values))
523+
524+
if mask is not None:
525+
np.putmask(new_labels, mask, na_sentinel)
506526

507527
return ordered, ensure_platform_int(new_labels)

pandas/tests/test_algos.py

+14 -5
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import pandas.core.algorithms as algos
2222
from pandas.core.arrays import DatetimeArray
2323
import pandas.core.common as com
24+
from pandas.core.sorting import safe_sort
2425
import pandas.util.testing as tm
2526
from pandas.util.testing import assert_almost_equal
2627

@@ -324,18 +325,26 @@ def test_parametrized_factorize_na_value(self, data, na_value):
324325

325326
@pytest.mark.parametrize('sort', [True, False])
326327
@pytest.mark.parametrize('na_sentinel', [-1, -10, 100])
327-
def test_factorize_na_sentinel(self, sort, na_sentinel):
328-
data = np.array(['b', 'a', None, 'b'], dtype=object)
328+
@pytest.mark.parametrize('data, uniques', [
329+
(np.array(['b', 'a', None, 'b'], dtype=object),
330+
np.array(['b', 'a'], dtype=object)),
331+
(pd.array([2, 1, np.nan, 2], dtype='Int64'),
332+
pd.array([2, 1], dtype='Int64'))],
333+
ids=['numpy_array', 'extension_array'])
334+
def test_factorize_na_sentinel(self, sort, na_sentinel, data, uniques):
329335
labels, uniques = algos.factorize(data, sort=sort,
330336
na_sentinel=na_sentinel)
331337
if sort:
332338
expected_labels = np.array([1, 0, na_sentinel, 1], dtype=np.intp)
333-
expected_uniques = np.array(['a', 'b'], dtype=object)
339+
expected_uniques = safe_sort(uniques)
334340
else:
335341
expected_labels = np.array([0, 1, na_sentinel, 0], dtype=np.intp)
336-
expected_uniques = np.array(['b', 'a'], dtype=object)
342+
expected_uniques = uniques
337343
tm.assert_numpy_array_equal(labels, expected_labels)
338-
tm.assert_numpy_array_equal(uniques, expected_uniques)
344+
if isinstance(data, np.ndarray):
345+
tm.assert_numpy_array_equal(uniques, expected_uniques)
346+
else:
347+
tm.assert_extension_array_equal(uniques, expected_uniques)
339348

340349

341350
class TestUnique:

pandas/tests/test_sorting.py

+41 -12
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,8 @@
66
from numpy import nan
77
import pytest
88

9-
from pandas import DataFrame, MultiIndex, Series, concat, merge, to_datetime
9+
from pandas import (
10+
DataFrame, MultiIndex, Series, array, concat, merge, to_datetime)
1011
from pandas.core import common as com
1112
from pandas.core.sorting import (
1213
decons_group_index, get_group_index, is_int64_overflow_possible,
@@ -358,34 +359,43 @@ def test_basic_sort(self):
358359
expected = np.array([])
359360
tm.assert_numpy_array_equal(result, expected)
360361

361-
def test_labels(self):
362+
@pytest.mark.parametrize('verify', [True, False])
363+
def test_labels(self, verify):
362364
values = [3, 1, 2, 0, 4]
363365
expected = np.array([0, 1, 2, 3, 4])
364366

365367
labels = [0, 1, 1, 2, 3, 0, -1, 4]
366-
result, result_labels = safe_sort(values, labels)
368+
result, result_labels = safe_sort(values, labels, verify=verify)
367369
expected_labels = np.array([3, 1, 1, 2, 0, 3, -1, 4], dtype=np.intp)
368370
tm.assert_numpy_array_equal(result, expected)
369371
tm.assert_numpy_array_equal(result_labels, expected_labels)
370372

371373
# na_sentinel
372374
labels = [0, 1, 1, 2, 3, 0, 99, 4]
373-
result, result_labels = safe_sort(values, labels,
374-
na_sentinel=99)
375+
result, result_labels = safe_sort(values, labels, na_sentinel=99,
376+
verify=verify)
375377
expected_labels = np.array([3, 1, 1, 2, 0, 3, 99, 4], dtype=np.intp)
376378
tm.assert_numpy_array_equal(result, expected)
377379
tm.assert_numpy_array_equal(result_labels, expected_labels)
378380

379-
# out of bound indices
380-
labels = [0, 101, 102, 2, 3, 0, 99, 4]
381-
result, result_labels = safe_sort(values, labels)
382-
expected_labels = np.array([3, -1, -1, 2, 0, 3, -1, 4], dtype=np.intp)
381+
labels = []
382+
result, result_labels = safe_sort(values, labels, verify=verify)
383+
expected_labels = np.array([], dtype=np.intp)
383384
tm.assert_numpy_array_equal(result, expected)
384385
tm.assert_numpy_array_equal(result_labels, expected_labels)
385386

386-
labels = []
387-
result, result_labels = safe_sort(values, labels)
388-
expected_labels = np.array([], dtype=np.intp)
387+
@pytest.mark.parametrize('na_sentinel', [-1, 99])
388+
def test_labels_out_of_bound(self, na_sentinel):
389+
values = [3, 1, 2, 0, 4]
390+
expected = np.array([0, 1, 2, 3, 4])
391+
392+
# out of bound indices
393+
labels = [0, 101, 102, 2, 3, 0, 99, 4]
394+
result, result_labels = safe_sort(
395+
values, labels, na_sentinel=na_sentinel)
396+
expected_labels = np.array(
397+
[3, na_sentinel, na_sentinel, 2, 0, 3, na_sentinel, 4],
398+
dtype=np.intp)
389399
tm.assert_numpy_array_equal(result, expected)
390400
tm.assert_numpy_array_equal(result_labels, expected_labels)
391401

@@ -430,3 +440,22 @@ def test_exceptions(self):
430440
with pytest.raises(ValueError,
431441
match="values should be unique"):
432442
safe_sort(values=[0, 1, 2, 1], labels=[0, 1])
443+
444+
def test_extension_array(self):
445+
# a = array([1, 3, np.nan, 2], dtype='Int64')
446+
a = array([1, 3, 2], dtype='Int64')
447+
result = safe_sort(a)
448+
# expected = array([1, 2, 3, np.nan], dtype='Int64')
449+
expected = array([1, 2, 3], dtype='Int64')
450+
tm.assert_extension_array_equal(result, expected)
451+
452+
@pytest.mark.parametrize('verify', [True, False])
453+
@pytest.mark.parametrize('na_sentinel', [-1, 99])
454+
def test_extension_array_labels(self, verify, na_sentinel):
455+
a = array([1, 3, 2], dtype='Int64')
456+
result, labels = safe_sort(a, [0, 1, na_sentinel, 2],
457+
na_sentinel=na_sentinel, verify=verify)
458+
expected_values = array([1, 2, 3], dtype='Int64')
459+
expected_labels = np.array([0, 2, na_sentinel, 1], dtype=np.intp)
460+
tm.assert_extension_array_equal(result, expected_values)
461+
tm.assert_numpy_array_equal(labels, expected_labels)

0 commit comments

Comments
 (0)