Skip to content

BUG: Fix BaseWindowGroupby.aggregate where as_index is ignored #54973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ Bug fixes
~~~~~~~~~
- Bug in :class:`AbstractHolidayCalendar` where timezone data was not propagated when computing holiday observances (:issue:`54580`)
- Bug in :class:`pandas.core.window.Rolling` where duplicate datetimelike indexes are treated as consecutive rather than equal with ``closed='left'`` and ``closed='neither'`` (:issue:`20712`)
- Bug in :meth:`DataFrameGroupBy.rolling.agg` where ``as_index`` is ignored with list-like and dictionary-like ``func`` parameters (:issue:`31007`)

Categorical
^^^^^^^^^^^
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/window/ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,11 @@ def _cov(X, Y):
)


class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
# error: Definition of "agg" in base class "BaseWindowGroupby" is
# incompatible with definition in base class "ExponentialMovingWindow"
class ExponentialMovingWindowGroupby( # type: ignore[misc]
BaseWindowGroupby, ExponentialMovingWindow
):
"""
Provide an exponential moving window groupby implementation.
"""
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/window/expanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,9 @@ def corr(
)


class ExpandingGroupby(BaseWindowGroupby, Expanding):
# error: Definition of "agg" in base class "BaseWindowGroupby" is
# incompatible with definition in base class "Expanding"
class ExpandingGroupby(BaseWindowGroupby, Expanding): # type: ignore[misc]
"""
Provide an expanding groupby implementation.
"""
Expand Down
14 changes: 13 additions & 1 deletion pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from pandas.core.dtypes.common import (
ensure_float64,
is_bool,
is_dict_like,
is_integer,
is_list_like,
is_numeric_dtype,
needs_i8_conversion,
)
Expand Down Expand Up @@ -875,6 +877,14 @@ def _gotitem(self, key, ndim, subset=None):
subset = self.obj.set_index(self._on)
return super()._gotitem(key, ndim, subset=subset)

def aggregate(self, func, *args, **kwargs):
    """
    Aggregate via the parent implementation, then honour ``as_index=False``.

    Parameters
    ----------
    func : object
        Aggregation specification; forwarded unchanged to the parent
        ``aggregate``.
    *args, **kwargs
        Forwarded unchanged to the parent ``aggregate``.

    Returns
    -------
    object
        The parent's result. When ``self._as_index`` is falsy and ``func``
        is list-like or dict-like, ``reset_index`` is called on that result
        for the first ``len(self._grouper.names)`` index levels and the
        outcome of that call is returned instead.
    """
    aggregated = super().aggregate(func, *args, **kwargs)

    # ``as_index`` is set: hand back the parent's result untouched.
    if self._as_index:
        return aggregated

    # Only list-like / dict-like ``func`` results are post-processed here.
    if not (is_list_like(func) or is_dict_like(func)):
        return aggregated

    # NOTE(review): presumably the leading index levels hold the group keys,
    # one level per grouper name -- confirm against the parent's output.
    n_group_levels = len(self._grouper.names)
    return aggregated.reset_index(level=list(range(n_group_levels)))

agg = aggregate


class Window(BaseWindow):
"""
Expand Down Expand Up @@ -2852,7 +2862,9 @@ def corr(
Rolling.__doc__ = Window.__doc__


class RollingGroupby(BaseWindowGroupby, Rolling):
# error: Definition of "agg" in base class "BaseWindowGroupby" is
# incompatible with definition in base class "Rolling"
class RollingGroupby(BaseWindowGroupby, Rolling): # type: ignore[misc]
"""
Provide a rolling groupby implementation.
"""
Expand Down
36 changes: 36 additions & 0 deletions pandas/tests/window/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,42 @@ def test_as_index_false(self, by, expected_data):
)
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize(
    "f", ["mean", lambda x: x.mean(), {"agg_col": "mean"}, ["mean"]]
)
def test_aggregate_as_index_false(self, f):
    # GH 31007
    # Two groups of five rows each, rolled with a window of 4, grouped with
    # ``as_index=False``; ``f`` covers string, callable, dict and list forms.
    dates = date_range(end="2020-01-01", periods=10)
    keys = ["A"] * 5 + ["B"] * 5
    values = [1, 1, 0, 1, 0, 0, 0, 0, 1, 0]
    df = DataFrame({"groupby_col": keys, "agg_col": values}, index=dates)

    rolling = df.groupby("groupby_col", as_index=False).rolling(4)
    result = rolling.agg(f)
    if isinstance(f, list):
        # Keep only the first column level so the frame can be compared
        # against the same ``expected`` as the other parametrizations.
        result.columns = result.columns.get_level_values(0)

    # The first three rows of each group are expected to be NaN.
    leading_nans = [np.nan] * 3
    expected_means = leading_nans + [0.75, 0.5] + leading_nans + [0.25, 0.25]
    expected = DataFrame(
        {"groupby_col": keys, "agg_col": expected_means},
        index=dates,
    )
    tm.assert_frame_equal(result, expected)

def test_nan_and_zero_endpoints(self, any_int_numpy_dtype):
# https://github.com/twosigma/pandas/issues/53
typ = np.dtype(any_int_numpy_dtype).type
Expand Down