Skip to content

Commit a37f1a4

Browse files
BUG/REG: RollingGroupby MultiIndex levels dropped (pandas-dev#38737)
Co-authored-by: Simon Hawkins <[email protected]>
1 parent beb4f1b commit a37f1a4

File tree

4 files changed

+55
-23
lines changed

4 files changed

+55
-23
lines changed

doc/source/whatsnew/v1.2.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Fixed regressions
1616
~~~~~~~~~~~~~~~~~
1717
- The deprecated attributes ``_AXIS_NAMES`` and ``_AXIS_NUMBERS`` of :class:`DataFrame` and :class:`Series` will no longer show up in ``dir`` or ``inspect.getmembers`` calls (:issue:`38740`)
1818
- :meth:`to_csv` created corrupted zip files when there were more rows than ``chunksize`` (issue:`38714`)
19+
- Fixed a regression in ``groupby().rolling()`` where :class:`MultiIndex` levels were dropped (:issue:`38523`)
1920
- Bug in repr of float-like strings of an ``object`` dtype having trailing 0's truncated after the decimal (:issue:`38708`)
2021
-
2122

pandas/core/shared_docs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
Note this does not influence the order of observations within each
109109
group. Groupby preserves the order of rows within each group.
110110
group_keys : bool, default True
111-
When calling apply, add group keys to index to identify pieces.
111+
When calling ``groupby().apply()``, add group keys to index to identify pieces.
112112
squeeze : bool, default False
113113
Reduce the dimensionality of the return type if possible,
114114
otherwise return a consistent type.

pandas/core/window/rolling.py

+11-17
Original file line numberDiff line numberDiff line change
@@ -769,28 +769,22 @@ def _apply(
769769
numba_cache_key,
770770
**kwargs,
771771
)
772-
# Reconstruct the resulting MultiIndex from tuples
772+
# Reconstruct the resulting MultiIndex
773773
# 1st set of levels = group by labels
774-
# 2nd set of levels = original index
775-
# Ignore 2nd set of levels if a group by label include an index level
776-
result_index_names = [
777-
grouping.name for grouping in self._groupby.grouper._groupings
778-
]
779-
grouped_object_index = None
774+
# 2nd set of levels = original DataFrame/Series index
775+
grouped_object_index = self.obj.index
776+
grouped_index_name = [*grouped_object_index.names]
777+
groupby_keys = [grouping.name for grouping in self._groupby.grouper._groupings]
778+
result_index_names = groupby_keys + grouped_index_name
780779

781-
column_keys = [
780+
drop_columns = [
782781
key
783-
for key in result_index_names
782+
for key in groupby_keys
784783
if key not in self.obj.index.names or key is None
785784
]
786-
787-
if len(column_keys) == len(result_index_names):
788-
grouped_object_index = self.obj.index
789-
grouped_index_name = [*grouped_object_index.names]
790-
result_index_names += grouped_index_name
791-
else:
792-
# Our result will have still kept the column in the result
793-
result = result.drop(columns=column_keys, errors="ignore")
785+
if len(drop_columns) != len(groupby_keys):
786+
# Our result will have kept groupby columns which should be dropped
787+
result = result.drop(columns=drop_columns, errors="ignore")
794788

795789
codes = self._groupby.grouper.codes
796790
levels = self._groupby.grouper.levels

pandas/tests/window/test_groupby.py

+42-5
Original file line numberDiff line numberDiff line change
@@ -556,23 +556,31 @@ def test_groupby_rolling_nans_in_index(self, rollings, key):
556556
with pytest.raises(ValueError, match=f"{key} must be monotonic"):
557557
df.groupby("c").rolling("60min", **rollings)
558558

559-
def test_groupby_rolling_group_keys(self):
559+
@pytest.mark.parametrize("group_keys", [True, False])
560+
def test_groupby_rolling_group_keys(self, group_keys):
560561
# GH 37641
562+
# GH 38523: GH 37641 actually was not a bug.
563+
# group_keys only applies to groupby.apply directly
561564
arrays = [["val1", "val1", "val2"], ["val1", "val1", "val2"]]
562565
index = MultiIndex.from_arrays(arrays, names=("idx1", "idx2"))
563566

564567
s = Series([1, 2, 3], index=index)
565-
result = s.groupby(["idx1", "idx2"], group_keys=False).rolling(1).mean()
568+
result = s.groupby(["idx1", "idx2"], group_keys=group_keys).rolling(1).mean()
566569
expected = Series(
567570
[1.0, 2.0, 3.0],
568571
index=MultiIndex.from_tuples(
569-
[("val1", "val1"), ("val1", "val1"), ("val2", "val2")],
570-
names=["idx1", "idx2"],
572+
[
573+
("val1", "val1", "val1", "val1"),
574+
("val1", "val1", "val1", "val1"),
575+
("val2", "val2", "val2", "val2"),
576+
],
577+
names=["idx1", "idx2", "idx1", "idx2"],
571578
),
572579
)
573580
tm.assert_series_equal(result, expected)
574581

575582
def test_groupby_rolling_index_level_and_column_label(self):
583+
# The groupby keys should not appear as a resulting column
576584
arrays = [["val1", "val1", "val2"], ["val1", "val1", "val2"]]
577585
index = MultiIndex.from_arrays(arrays, names=("idx1", "idx2"))
578586

@@ -581,7 +589,12 @@ def test_groupby_rolling_index_level_and_column_label(self):
581589
expected = DataFrame(
582590
{"B": [0.0, 1.0, 2.0]},
583591
index=MultiIndex.from_tuples(
584-
[("val1", 1), ("val1", 1), ("val2", 2)], names=["idx1", "A"]
592+
[
593+
("val1", 1, "val1", "val1"),
594+
("val1", 1, "val1", "val1"),
595+
("val2", 2, "val2", "val2"),
596+
],
597+
names=["idx1", "A", "idx1", "idx2"],
585598
),
586599
)
587600
tm.assert_frame_equal(result, expected)
@@ -640,6 +653,30 @@ def test_groupby_rolling_resulting_multiindex(self):
640653
)
641654
tm.assert_index_equal(result.index, expected_index)
642655

656+
def test_groupby_level(self):
657+
# GH 38523
658+
arrays = [
659+
["Falcon", "Falcon", "Parrot", "Parrot"],
660+
["Captive", "Wild", "Captive", "Wild"],
661+
]
662+
index = MultiIndex.from_arrays(arrays, names=("Animal", "Type"))
663+
df = DataFrame({"Max Speed": [390.0, 350.0, 30.0, 20.0]}, index=index)
664+
result = df.groupby(level=0)["Max Speed"].rolling(2).sum()
665+
expected = Series(
666+
[np.nan, 740.0, np.nan, 50.0],
667+
index=MultiIndex.from_tuples(
668+
[
669+
("Falcon", "Falcon", "Captive"),
670+
("Falcon", "Falcon", "Wild"),
671+
("Parrot", "Parrot", "Captive"),
672+
("Parrot", "Parrot", "Wild"),
673+
],
674+
names=["Animal", "Animal", "Type"],
675+
),
676+
name="Max Speed",
677+
)
678+
tm.assert_series_equal(result, expected)
679+
643680

644681
class TestExpanding:
645682
def setup_method(self):

0 commit comments

Comments
 (0)