diff --git a/doc/source/whatsnew/v0.22.0.txt b/doc/source/whatsnew/v0.22.0.txt index ab7f18bce47d3..53d8aa5946845 100644 --- a/doc/source/whatsnew/v0.22.0.txt +++ b/doc/source/whatsnew/v0.22.0.txt @@ -138,6 +138,8 @@ Other Enhancements - :func:`Series` / :func:`DataFrame` tab completion also returns identifiers in the first level of a :func:`MultiIndex`. (:issue:`16326`) - :func:`read_excel()` has gained the ``nrows`` parameter (:issue:`16645`) - :func:``DataFrame.to_json`` and ``Series.to_json`` now accept an ``index`` argument which allows the user to exclude the index from the JSON output (:issue:`17394`) +- :func:`Series` / :func:`DataFrame` methods :func:`nlargest` / :func:`nsmallest` now accept the value 'all' for the ``keep`` argument. This keeps all ties for the nth largest/smallest value (:issue:`16818`). + .. _whatsnew_0220.api_breaking: diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 0ceb8966fd3c8..099fec74d266c 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -910,8 +910,8 @@ def __init__(self, obj, n, keep): self.n = n self.keep = keep - if self.keep not in ('first', 'last'): - raise ValueError('keep must be either "first", "last"') + if self.keep not in ('first', 'last', 'all'): + raise ValueError('keep must be either "first", "last", or "all"') def nlargest(self): return self.compute('nlargest') @@ -979,7 +979,11 @@ def compute(self, method): kth_val = algos.kth_smallest(arr.copy(), n - 1) ns, = np.nonzero(arr <= kth_val) - inds = ns[arr[ns].argsort(kind='mergesort')][:n] + inds = ns[arr[ns].argsort(kind='mergesort')] + + if self.keep != 'all': + inds = inds[:n] + if self.keep == 'last': # reverse indices inds = narr - 1 - inds diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 5f323d0f040bc..d1441da0d810f 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3769,10 +3769,13 @@ def nlargest(self, n, columns, keep='first'): Number of items to retrieve columns : list or str Column 
name or names to order by - keep : {'first', 'last'}, default 'first' + keep : {'first', 'last', 'all'}, default 'first' Where there are duplicate values: - - ``first`` : take the first occurrence. - - ``last`` : take the last occurrence. + - 'first' : take the first occurrence. + - 'last' : take the last occurrence. + - 'all' : keep all ties of nth largest value. + + .. versionadded:: 0.22.0 Returns ------- @@ -3780,14 +3783,28 @@ def nlargest(self, n, columns, keep='first'): Examples -------- - >>> df = DataFrame({'a': [1, 10, 8, 11, -1], - ... 'b': list('abdce'), - ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]}) - >>> df.nlargest(3, 'a') + >>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2], + ... 'b': list('abdcef'), + ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]}) + + >>> df.nlargest(3, 'a', keep='first') + a b c + 3 11 c 3 + 1 10 b 2 + 2 8 d NaN + + >>> df.nlargest(3, 'a', keep='last') + a b c + 3 11 c 3 + 1 10 b 2 + 4 8 e 4 + + >>> df.nlargest(3, 'a', keep='all') a b c 3 11 c 3 1 10 b 2 2 8 d NaN + 4 8 e 4 """ return algorithms.SelectNFrame(self, n=n, @@ -3804,10 +3821,13 @@ def nsmallest(self, n, columns, keep='first'): Number of items to retrieve columns : list or str Column name or names to order by - keep : {'first', 'last'}, default 'first' + keep : {'first', 'last', 'all'}, default 'first' Where there are duplicate values: - - ``first`` : take the first occurrence. - - ``last`` : take the last occurrence. + - 'first' : take the first occurrence. + - 'last' : take the last occurrence. + - 'all' : keep all ties of nth smallest value. + + .. versionadded:: 0.22.0 Returns ------- @@ -3815,14 +3835,28 @@ def nsmallest(self, n, columns, keep='first'): Examples -------- - >>> df = DataFrame({'a': [1, 10, 8, 11, -1], - ... 'b': list('abdce'), - ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]}) - >>> df.nsmallest(3, 'a') - a b c - 4 -1 e 4 - 0 1 a 1 - 2 8 d NaN + >>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2], + ... 'b': list('abdcef'), + ... 
'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]}) + + >>> df.nsmallest(3, 'a', keep='first') + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 2 8 d NaN + + >>> df.nsmallest(3, 'a', keep='last') + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 4 8 e 4.0 + + >>> df.nsmallest(3, 'a', keep='all') + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 2 8 d NaN + 4 8 e 4.0 """ return algorithms.SelectNFrame(self, n=n, diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index 4bba6d7601ae8..c038d76879ce1 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -2202,6 +2202,22 @@ def test_n_duplicate_index(self, df_duplicates, n, order): expected = df.sort_values(order, ascending=False).head(n) tm.assert_frame_equal(result, expected) + def test_keep_all_ties(self): + # GH 16818 + df = pd.DataFrame({'a': [5, 4, 4, 2, 3, 3, 3, 3], + 'b': [10, 9, 8, 7, 5, 50, 10, 20]}) + result = df.nlargest(4, 'a', keep='all') + expected = pd.DataFrame({'a': {0: 5, 1: 4, 2: 4, 4: 3, + 5: 3, 6: 3, 7: 3}, + 'b': {0: 10, 1: 9, 2: 8, 4: 5, + 5: 50, 6: 10, 7: 20}}) + tm.assert_frame_equal(result, expected) + + result = df.nsmallest(2, 'a', keep='all') + expected = pd.DataFrame({'a': {3: 2, 4: 3, 5: 3, 6: 3, 7: 3}, + 'b': {3: 7, 4: 5, 5: 50, 6: 10, 7: 20}}) + tm.assert_frame_equal(result, expected) + def test_series_broadcasting(self): # smoke test for numpy warnings # GH 16378, GH 16306 diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index 289b5c01c1263..2e4a84f8bcd6b 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -1867,6 +1867,17 @@ def test_n(self, n): expected = s.sort_values().head(n) assert_series_equal(result, expected) + def test_keep_all_ties(self): + # GH 16818 + s = Series([10, 9, 8, 7, 7, 7, 7, 6]) + result = s.nlargest(4, keep='all') + expected = Series([10, 9, 8, 7, 7, 7, 7]) + assert_series_equal(result, expected) + + result = s.nsmallest(2, 
keep='all') + expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6]) + assert_series_equal(result, expected) + class TestCategoricalSeriesAnalytics(object):