Skip to content

Commit 766ea00

Browse files
Backport PR #44824: BUG: Fix regression in groupby.rolling.corr/cov when other is same size as each group (#44848)
Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 1804780 commit 766ea00

File tree

3 files changed

+39
-7
lines changed

3 files changed

+39
-7
lines changed

doc/source/whatsnew/v1.3.5.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Fixed regressions
2121
- Fixed performance regression in :func:`read_csv` (:issue:`44106`)
2222
- Fixed regression in :meth:`Series.duplicated` and :meth:`Series.drop_duplicates` when Series has :class:`Categorical` dtype with boolean categories (:issue:`44351`)
2323
- Fixed regression in :meth:`.GroupBy.sum` with ``timedelta64[ns]`` dtype containing ``NaT`` failing to treat that value as NA (:issue:`42659`)
24-
-
24+
- Fixed regression in :meth:`.RollingGroupby.cov` and :meth:`.RollingGroupby.corr` when ``other`` had the same shape as each group would incorrectly return superfluous groups in the result (:issue:`42915`)
2525

2626
.. ---------------------------------------------------------------------------
2727

pandas/core/window/rolling.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,11 @@ def _apply_pairwise(
658658
target = self._create_data(target)
659659
result = super()._apply_pairwise(target, other, pairwise, func)
660660
# 1) Determine the levels + codes of the groupby levels
661-
if other is not None:
662-
# When we have other, we must reindex (expand) the result
661+
if other is not None and not all(
662+
len(group) == len(other) for group in self._grouper.indices.values()
663+
):
664+
# GH 42915
665+
# len(other) != len(any group), so must reindex (expand) the result
663666
# from flex_binary_moment to a "transform"-like result
664667
# per groupby combination
665668
old_result_len = len(result)
@@ -681,10 +684,9 @@ def _apply_pairwise(
681684
codes, levels = factorize(labels)
682685
groupby_codes.append(codes)
683686
groupby_levels.append(levels)
684-
685687
else:
686-
# When we evaluate the pairwise=True result, repeat the groupby
687-
# labels by the number of columns in the original object
688+
# pairwise=True or len(other) == len(each group), so repeat
689+
# the groupby labels by the number of columns in the original object
688690
groupby_codes = self._grouper.codes
689691
# error: Incompatible types in assignment (expression has type
690692
# "List[Index]", variable has type "List[Union[ndarray, Index]]")

pandas/tests/window/test_groupby.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,33 @@ def test_rolling_quantile(self, interpolation):
122122
expected.index = expected_index
123123
tm.assert_frame_equal(result, expected)
124124

125+
@pytest.mark.parametrize("f, expected_val", [["corr", 1], ["cov", 0.5]])
126+
def test_rolling_corr_cov_other_same_size_as_groups(self, f, expected_val):
127+
# GH 42915
128+
df = DataFrame(
129+
{"value": range(10), "idx1": [1] * 5 + [2] * 5, "idx2": [1, 2, 3, 4, 5] * 2}
130+
).set_index(["idx1", "idx2"])
131+
other = DataFrame({"value": range(5), "idx2": [1, 2, 3, 4, 5]}).set_index(
132+
"idx2"
133+
)
134+
result = getattr(df.groupby(level=0).rolling(2), f)(other)
135+
expected_data = ([np.nan] + [expected_val] * 4) * 2
136+
expected = DataFrame(
137+
expected_data,
138+
columns=["value"],
139+
index=MultiIndex.from_arrays(
140+
[
141+
[1] * 5 + [2] * 5,
142+
[1] * 5 + [2] * 5,
143+
list(range(1, 6)) * 2,
144+
],
145+
names=["idx1", "idx1", "idx2"],
146+
),
147+
)
148+
tm.assert_frame_equal(result, expected)
149+
125150
@pytest.mark.parametrize("f", ["corr", "cov"])
126-
def test_rolling_corr_cov(self, f):
151+
def test_rolling_corr_cov_other_diff_size_as_groups(self, f):
127152
g = self.frame.groupby("A")
128153
r = g.rolling(window=4)
129154

@@ -138,6 +163,11 @@ def func(x):
138163
expected["A"] = np.nan
139164
tm.assert_frame_equal(result, expected)
140165

166+
@pytest.mark.parametrize("f", ["corr", "cov"])
167+
def test_rolling_corr_cov_pairwise(self, f):
168+
g = self.frame.groupby("A")
169+
r = g.rolling(window=4)
170+
141171
result = getattr(r.B, f)(pairwise=True)
142172

143173
def func(x):

0 commit comments

Comments
 (0)