Skip to content

BUG: SeriesGroupBy.nlargest/smallest inconsistent shape #42596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ Groupby/resample/rolling
- Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`42287`)
- Bug in :meth:`DataFrame.groupby.rolling.var` would calculate the rolling variance only on the first group (:issue:`42442`)
- Bug in :meth:`GroupBy.shift` that would return the grouping columns if ``fill_value`` was not None (:issue:`41556`)
- Bug in :meth:`SeriesGroupBy.nlargest` and :meth:`SeriesGroupBy.nsmallest` where the result would have an inconsistent index when the input Series was sorted and ``n`` was greater than or equal to all group sizes (:issue:`15272`, :issue:`16345`, :issue:`29129`)
- Bug in :meth:`pandas.DataFrame.ewm`, where non-float64 dtypes were silently failing (:issue:`42452`)

Reshaping
Expand Down
18 changes: 18 additions & 0 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,24 @@ def pct_change(self, periods=1, fill_method="pad", limit=None, freq=None):

return (filled / shifted) - 1

@doc(Series.nlargest)
def nlargest(self, n: int = 5, keep: str = "first"):
    # The user-facing docstring is supplied by @doc from Series.nlargest.
    #
    # not_indexed_same is passed explicitly instead of being inferred by the
    # apply machinery: when every group is already ordered and n >= all group
    # sizes, the per-group results carry the same index as the input, and
    # that coincidence must not alter how the output is assembled.
    per_group = partial(Series.nlargest, n=n, keep=keep)
    return self._python_apply_general(
        per_group,
        self._obj_with_exclusions,
        not_indexed_same=True,
    )

@doc(Series.nsmallest)
def nsmallest(self, n: int = 5, keep: str = "first"):
    # The user-facing docstring is supplied by @doc from Series.nsmallest.
    #
    # not_indexed_same is passed explicitly instead of being inferred by the
    # apply machinery: when every group is already ordered and n >= all group
    # sizes, the per-group results carry the same index as the input, and
    # that coincidence must not alter how the output is assembled.
    per_group = partial(Series.nsmallest, n=n, keep=keep)
    return self._python_apply_general(
        per_group,
        self._obj_with_exclusions,
        not_indexed_same=True,
    )


@pin_allowlisted_properties(DataFrame, base.dataframe_apply_allowlist)
class DataFrameGroupBy(GroupBy[DataFrame]):
Expand Down
11 changes: 9 additions & 2 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,7 @@ def f(g):

@final
def _python_apply_general(
self, f: F, data: DataFrame | Series
self, f: F, data: DataFrame | Series, not_indexed_same: bool | None = None
) -> DataFrame | Series:
"""
Apply function f in python space
Expand All @@ -1286,6 +1286,10 @@ def _python_apply_general(
Function to apply
data : Series or DataFrame
Data to apply f to
not_indexed_same: bool, optional
When specified, overrides the value of not_indexed_same. Apply behaves
differently when the result index is equal to the input index, but
this can be coincidental leading to value-dependent behavior.

Returns
-------
Expand All @@ -1294,8 +1298,11 @@ def _python_apply_general(
"""
keys, values, mutated = self.grouper.apply(f, data, self.axis)

if not_indexed_same is None:
not_indexed_same = mutated or self.mutated

return self._wrap_applied_output(
data, keys, values, not_indexed_same=mutated or self.mutated
data, keys, values, not_indexed_same=not_indexed_same
)

@final
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/groupby/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,23 @@ def test_nsmallest():
tm.assert_series_equal(gb.nsmallest(3, keep="last"), e)


@pytest.mark.parametrize(
    "data, groups",
    [([0, 1, 2, 3], [0, 0, 1, 1]), ([0], [0])],
)
@pytest.mark.parametrize("method", ["nlargest", "nsmallest"])
def test_nlargest_and_smallest_noop(data, groups, method):
    # GH 15272, GH 16345, GH 29129
    # Exercise nlargest/nsmallest when the selection is a no-op: the input
    # is already in the order the method would produce and every group has
    # at most n rows, so each group comes back unchanged.
    # The parametrized data is ascending; nlargest needs it descending.
    values = data[::-1] if method == "nlargest" else data
    ser = Series(values, name="a")

    grouped = ser.groupby(groups)
    result = getattr(grouped, method)(n=2)

    # The group keys must form the outer index level even though the
    # per-group result index equals the input index.
    expected_index = MultiIndex.from_arrays([groups, ser.index])
    expected = Series(values, index=expected_index, name="a")
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("func", ["cumprod", "cumsum"])
def test_numpy_compat(func):
# see gh-12811
Expand Down