Skip to content

Commit 0a72afb

Browse files
authored
BUG: SeriesGroupBy.nlargest/smallest inconsistent shape (#42596)
1 parent fcadfb8 commit 0a72afb

File tree

4 files changed

+45
-2
lines changed

4 files changed

+45
-2
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ Groupby/resample/rolling
276276
- Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`42287`)
277277
- Bug in :meth:`DataFrame.groupby.rolling.var` where the rolling variance was calculated only on the first group (:issue:`42442`)
278278
- Bug in :meth:`GroupBy.shift` that would return the grouping columns if ``fill_value`` was not None (:issue:`41556`)
279+
- Bug in :meth:`SeriesGroupBy.nlargest` and :meth:`SeriesGroupBy.nsmallest` where the result would have an inconsistent index when the input Series was sorted and ``n`` was greater than or equal to all group sizes (:issue:`15272`, :issue:`16345`, :issue:`29129`)
279280
- Bug in :meth:`pandas.DataFrame.ewm`, where non-float64 dtypes were silently failing (:issue:`42452`)
280281
- Bug in :meth:`pandas.DataFrame.rolling` operation along rows (``axis=1``) incorrectly omits columns containing ``float16`` and ``float32`` (:issue:`41779`)
281282

pandas/core/groupby/generic.py

+18
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,24 @@ def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None):
870870

871871
return (filled / shifted) - 1
872872

873+
@doc(Series.nlargest)
874+
def nlargest(self, n: int = 5, keep: str = "first"):
875+
f = partial(Series.nlargest, n=n, keep=keep)
876+
data = self._obj_with_exclusions
877+
# Don't change behavior if result index happens to be the same, i.e.
878+
# already ordered and n >= all group sizes.
879+
result = self._python_apply_general(f, data, not_indexed_same=True)
880+
return result
881+
882+
@doc(Series.nsmallest)
883+
def nsmallest(self, n: int = 5, keep: str = "first"):
884+
f = partial(Series.nsmallest, n=n, keep=keep)
885+
data = self._obj_with_exclusions
886+
# Don't change behavior if result index happens to be the same, i.e.
887+
# already ordered and n >= all group sizes.
888+
result = self._python_apply_general(f, data, not_indexed_same=True)
889+
return result
890+
873891

874892
@pin_allowlisted_properties(DataFrame, base.dataframe_apply_allowlist)
875893
class DataFrameGroupBy(GroupBy[DataFrame]):

pandas/core/groupby/groupby.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1275,7 +1275,7 @@ def f(g):
12751275

12761276
@final
12771277
def _python_apply_general(
1278-
self, f: F, data: DataFrame | Series
1278+
self, f: F, data: DataFrame | Series, not_indexed_same: bool | None = None
12791279
) -> DataFrame | Series:
12801280
"""
12811281
Apply function f in python space
@@ -1286,6 +1286,10 @@ def _python_apply_general(
12861286
Function to apply
12871287
data : Series or DataFrame
12881288
Data to apply f to
1289+
not_indexed_same : bool, optional
1290+
When specified, overrides the inferred value of not_indexed_same. Apply behaves
1291+
differently when the result index is equal to the input index, but
1292+
this can be coincidental leading to value-dependent behavior.
12891293
12901294
Returns
12911295
-------
@@ -1294,8 +1298,11 @@ def _python_apply_general(
12941298
"""
12951299
keys, values, mutated = self.grouper.apply(f, data, self.axis)
12961300

1301+
if not_indexed_same is None:
1302+
not_indexed_same = mutated or self.mutated
1303+
12971304
return self._wrap_applied_output(
1298-
data, keys, values, not_indexed_same=mutated or self.mutated
1305+
data, keys, values, not_indexed_same=not_indexed_same
12991306
)
13001307

13011308
@final

pandas/tests/groupby/test_function.py

+17
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,23 @@ def test_nsmallest():
680680
tm.assert_series_equal(gb.nsmallest(3, keep="last"), e)
681681

682682

683+
@pytest.mark.parametrize(
684+
"data, groups",
685+
[([0, 1, 2, 3], [0, 0, 1, 1]), ([0], [0])],
686+
)
687+
@pytest.mark.parametrize("method", ["nlargest", "nsmallest"])
688+
def test_nlargest_and_smallest_noop(data, groups, method):
689+
# GH 15272, GH 16345, GH 29129
690+
# Test nlargest/smallest when it results in a noop,
691+
# i.e. input is sorted and group size <= n
692+
if method == "nlargest":
693+
data = list(reversed(data))
694+
ser = Series(data, name="a")
695+
result = getattr(ser.groupby(groups), method)(n=2)
696+
expected = Series(data, index=MultiIndex.from_arrays([groups, ser.index]), name="a")
697+
tm.assert_series_equal(result, expected)
698+
699+
683700
@pytest.mark.parametrize("func", ["cumprod", "cumsum"])
684701
def test_numpy_compat(func):
685702
# see gh-12811

0 commit comments

Comments
 (0)