diff --git a/pandas/core/sorting.py b/pandas/core/sorting.py index ef69939d6e978..a5d3afd38adb2 100644 --- a/pandas/core/sorting.py +++ b/pandas/core/sorting.py @@ -8,7 +8,9 @@ from pandas.core.dtypes.cast import infer_dtype_from_array from pandas.core.dtypes.common import ( - ensure_int64, ensure_platform_int, is_categorical_dtype, is_list_like) + ensure_int64, ensure_platform_int, is_categorical_dtype, is_list_like, + is_extension_array_dtype) +from pandas.core.dtypes.generic import ABCIndexClass from pandas.core.dtypes.missing import isna import pandas.core.algorithms as algorithms @@ -239,7 +241,9 @@ def nargsort(items, kind='quicksort', ascending=True, na_position='last'): """ # specially handle Categorical - if is_categorical_dtype(items): + if is_extension_array_dtype(items): + if isinstance(items, ABCIndexClass): + items = items._values if na_position not in {'first', 'last'}: raise ValueError('invalid na_position: {!r}'.format(na_position)) diff --git a/pandas/tests/test_sorting.py b/pandas/tests/test_sorting.py index 7528566e8326e..0fbc44ec27a4c 100644 --- a/pandas/tests/test_sorting.py +++ b/pandas/tests/test_sorting.py @@ -9,7 +9,9 @@ from pandas.compat import PY2 -from pandas import DataFrame, MultiIndex, Series, compat, concat, merge +from pandas import ( + Categorical, DataFrame, MultiIndex, Series, compat, + concat, merge, to_datetime) from pandas.core import common as com from pandas.core.sorting import ( decons_group_index, get_group_index, is_int64_overflow_possible, @@ -183,6 +185,22 @@ def test_nargsort(self): exp = list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)) tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False) + @pytest.mark.parametrize('data', [ + Categorical(['a', 'c', 'a', 'b']), + to_datetime([0, 2, 0, 1]).tz_localize('Europe/Brussels')]) + def test_nargsort_extension_array(self, data): + result = nargsort(data) + expected = np.array([0, 2, 3, 1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + def test_nargsort_datetimearray_warning(self, recwarn): + # https://github.com/pandas-dev/pandas/issues/25439 + # can be removed once the FutureWarning for np.array(DTA) is removed + data = to_datetime([0, 2, 0, 1]).tz_localize('Europe/Brussels') + nargsort(data) + msg = "Converting timezone-aware DatetimeArray to timezone-naive" + assert len([w for w in recwarn.list if msg in str(w.message)]) == 0 + class TestMerge(object):