diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index a62d290277443..99ef76e0f4812 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -12,6 +12,7 @@ from pandas.types.common import (is_unsigned_integer_dtype, is_signed_integer_dtype, is_integer_dtype, + is_complex_dtype, is_categorical_dtype, is_extension_type, is_datetimetz, @@ -40,6 +41,44 @@ from pandas._libs.tslib import iNaT +# --------------- # +# dtype access # +# --------------- # + +def _ensure_data_view(values): + """ + helper routine to ensure that our data is of the correct + input dtype for lower-level routines + + Parameters + ---------- + values : array-like + """ + + if needs_i8_conversion(values): + values = values.view(np.int64) + elif is_period_arraylike(values): + from pandas.tseries.period import PeriodIndex + values = PeriodIndex(values).asi8 + elif is_categorical_dtype(values): + values = values.values.codes + elif isinstance(values, (ABCSeries, ABCIndex)): + values = values.values + + if is_signed_integer_dtype(values): + values = _ensure_int64(values) + elif is_unsigned_integer_dtype(values): + values = _ensure_uint64(values) + elif is_complex_dtype(values): + values = _ensure_float64(values) + elif is_float_dtype(values): + values = _ensure_float64(values) + else: + values = _ensure_object(values) + + return values + + # --------------- # # top-level algos # # --------------- # @@ -867,9 +906,7 @@ def nsmallest(arr, n, keep='first'): narr = len(arr) n = min(n, narr) - sdtype = str(arr.dtype) - arr = arr.view(_dtype_map.get(sdtype, sdtype)) - + arr = _ensure_data_view(arr) kth_val = algos.kth_smallest(arr.copy(), n - 1) return _finalize_nsmallest(arr, kth_val, n, keep, narr) @@ -880,8 +917,7 @@ def nlargest(arr, n, keep='first'): Note: Fails silently with NaN. 
""" - sdtype = str(arr.dtype) - arr = arr.view(_dtype_map.get(sdtype, sdtype)) + arr = _ensure_data_view(arr) return nsmallest(-arr, n, keep=keep) @@ -910,9 +946,10 @@ def select_n_series(series, n, keep, method): nordered : Series """ dtype = series.dtype - if not issubclass(dtype.type, (np.integer, np.floating, np.datetime64, - np.timedelta64)): - raise TypeError("Cannot use method %r with dtype %s" % (method, dtype)) + if not ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or + needs_i8_conversion(dtype)): + raise TypeError("Cannot use method '{method}' with " + "dtype {dtype}".format(method=method, dtype=dtype)) if keep not in ('first', 'last'): raise ValueError('keep must be either "first", "last"') @@ -964,9 +1001,6 @@ def _finalize_nsmallest(arr, kth_val, n, keep, narr): return inds -_dtype_map = {'datetime64[ns]': 'int64', 'timedelta64[ns]': 'int64'} - - # ------- # # helpers # # ------- # diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index b747a680c17dd..732142f1bce9a 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -1381,80 +1381,6 @@ def test_is_monotonic(self): self.assertFalse(s.is_monotonic) self.assertTrue(s.is_monotonic_decreasing) - def test_nsmallest_nlargest(self): - # float, int, datetime64 (use i8), timedelts64 (same), - # object that are numbers, object that are strings - - base = [3, 2, 1, 2, 5] - - s_list = [ - Series(base, dtype='int8'), - Series(base, dtype='int16'), - Series(base, dtype='int32'), - Series(base, dtype='int64'), - Series(base, dtype='float32'), - Series(base, dtype='float64'), - Series(base, dtype='uint8'), - Series(base, dtype='uint16'), - Series(base, dtype='uint32'), - Series(base, dtype='uint64'), - Series(base).astype('timedelta64[ns]'), - Series(pd.to_datetime(['2003', '2002', '2001', '2002', '2005'])), - ] - - raising = [ - Series([3., 2, 1, 2, '5'], dtype='object'), - Series([3., 2, 1, 2, 5], dtype='object'), - # 
not supported on some archs - # Series([3., 2, 1, 2, 5], dtype='complex256'), - Series([3., 2, 1, 2, 5], dtype='complex128'), - ] - - for r in raising: - dt = r.dtype - msg = "Cannot use method 'n(larg|small)est' with dtype %s" % dt - args = 2, len(r), 0, -1 - methods = r.nlargest, r.nsmallest - for method, arg in product(methods, args): - with tm.assertRaisesRegexp(TypeError, msg): - method(arg) - - for s in s_list: - - assert_series_equal(s.nsmallest(2), s.iloc[[2, 1]]) - assert_series_equal(s.nsmallest(2, keep='last'), s.iloc[[2, 3]]) - - empty = s.iloc[0:0] - assert_series_equal(s.nsmallest(0), empty) - assert_series_equal(s.nsmallest(-1), empty) - assert_series_equal(s.nlargest(0), empty) - assert_series_equal(s.nlargest(-1), empty) - - assert_series_equal(s.nsmallest(len(s)), s.sort_values()) - assert_series_equal(s.nsmallest(len(s) + 1), s.sort_values()) - assert_series_equal(s.nlargest(len(s)), s.iloc[[4, 0, 1, 3, 2]]) - assert_series_equal(s.nlargest(len(s) + 1), - s.iloc[[4, 0, 1, 3, 2]]) - - s = Series([3., np.nan, 1, 2, 5]) - assert_series_equal(s.nlargest(), s.iloc[[4, 0, 3, 2]]) - assert_series_equal(s.nsmallest(), s.iloc[[2, 3, 0, 4]]) - - msg = 'keep must be either "first", "last"' - with tm.assertRaisesRegexp(ValueError, msg): - s.nsmallest(keep='invalid') - with tm.assertRaisesRegexp(ValueError, msg): - s.nlargest(keep='invalid') - - # GH 13412 - s = Series([1, 4, 3, 2], index=[0, 0, 1, 1]) - result = s.nlargest(3) - expected = s.sort_values(ascending=False).head(3) - assert_series_equal(result, expected) - result = s.nsmallest(3) - expected = s.sort_values().head(3) - assert_series_equal(result, expected) - def test_sort_index_level(self): mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list('ABC')) s = Series([1, 2], mi) @@ -1729,3 +1655,109 @@ def test_value_counts_categorical_not_ordered(self): index=exp_idx, name='xxx') tm.assert_series_equal(s.value_counts(normalize=True), exp) 
tm.assert_series_equal(idx.value_counts(normalize=True), exp) + + +@pytest.fixture +def s_main_dtypes(): + df = pd.DataFrame( + {'datetime': pd.to_datetime(['2003', '2002', + '2001', '2002', + '2005']), + 'datetimetz': pd.to_datetime( + ['2003', '2002', + '2001', '2002', + '2005']).tz_localize('US/Eastern'), + 'timedelta': pd.to_timedelta(['3d', '2d', '1d', + '2d', '5d'])}) + + for dtype in ['int8', 'int16', 'int32', 'int64', + 'float32', 'float64', + 'uint8', 'uint16', 'uint32', 'uint64']: + df[dtype] = Series([3, 2, 1, 2, 5], dtype=dtype) + + return df + + +class TestNLargestNSmallest(object): + + @pytest.mark.parametrize( + "r", [Series([3., 2, 1, 2, '5'], dtype='object'), + Series([3., 2, 1, 2, 5], dtype='object'), + # not supported on some archs + # Series([3., 2, 1, 2, 5], dtype='complex256'), + Series([3., 2, 1, 2, 5], dtype='complex128'), + Series(list('abcde'), dtype='category'), + Series(list('abcde'))]) + def test_error(self, r): + dt = r.dtype + msg = ("Cannot use method 'n(larg|small)est' with " + "dtype {dt}".format(dt=dt)) + args = 2, len(r), 0, -1 + methods = r.nlargest, r.nsmallest + for method, arg in product(methods, args): + with tm.assertRaisesRegexp(TypeError, msg): + method(arg) + + @pytest.mark.parametrize( + "s", + [v for k, v in s_main_dtypes().iteritems()]) + def test_nsmallest_nlargest(self, s): + # float, int, datetime64 (use i8), timedelta64 (same); + # object, complex and category dtypes are rejected (see test_error) + + assert_series_equal(s.nsmallest(2), s.iloc[[2, 1]]) + assert_series_equal(s.nsmallest(2, keep='last'), s.iloc[[2, 3]]) + + empty = s.iloc[0:0] + assert_series_equal(s.nsmallest(0), empty) + assert_series_equal(s.nsmallest(-1), empty) + assert_series_equal(s.nlargest(0), empty) + assert_series_equal(s.nlargest(-1), empty) + + assert_series_equal(s.nsmallest(len(s)), s.sort_values()) + assert_series_equal(s.nsmallest(len(s) + 1), s.sort_values()) + assert_series_equal(s.nlargest(len(s)), s.iloc[[4, 0, 1, 3, 2]]) +
assert_series_equal(s.nlargest(len(s) + 1), + s.iloc[[4, 0, 1, 3, 2]]) + + def test_misc(self): + + s = Series([3., np.nan, 1, 2, 5]) + assert_series_equal(s.nlargest(), s.iloc[[4, 0, 3, 2]]) + assert_series_equal(s.nsmallest(), s.iloc[[2, 3, 0, 4]]) + + msg = 'keep must be either "first", "last"' + with tm.assertRaisesRegexp(ValueError, msg): + s.nsmallest(keep='invalid') + with tm.assertRaisesRegexp(ValueError, msg): + s.nlargest(keep='invalid') + + # GH 15297 + s = Series([1] * 5, index=[1, 2, 3, 4, 5]) + expected_first = Series([1] * 3, index=[1, 2, 3]) + expected_last = Series([1] * 3, index=[5, 4, 3]) + + result = s.nsmallest(3) + assert_series_equal(result, expected_first) + + result = s.nsmallest(3, keep='last') + assert_series_equal(result, expected_last) + + result = s.nlargest(3) + assert_series_equal(result, expected_first) + + result = s.nlargest(3, keep='last') + assert_series_equal(result, expected_last) + + @pytest.mark.parametrize('n', range(1, 5)) + def test_n(self, n): + + # GH 13412 + s = Series([1, 4, 3, 2], index=[0, 0, 1, 1]) + result = s.nlargest(n) + expected = s.sort_values(ascending=False).head(n) + assert_series_equal(result, expected) + + result = s.nsmallest(n) + expected = s.sort_values().head(n) + assert_series_equal(result, expected)