Skip to content

Commit 6a5dd0f

Browse files
authored
BUG: rolling.corr with MultiIndex columns (#43261)
* BUG: rolling.corr with MultiIndex columns * Add commentary of fix * Trigger CI * Use from_arrays * Consider reordering
1 parent b48635e commit 6a5dd0f

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ Groupby/resample/rolling
409409
- Bug in :meth:`pandas.DataFrame.rolling` operation along rows (``axis=1``) incorrectly omits columns containing ``float16`` and ``float32`` (:issue:`41779`)
410410
- Bug in :meth:`Resampler.aggregate` did not allow the use of Named Aggregation (:issue:`32803`)
411411
- Bug in :meth:`Series.rolling` when the :class:`Series` ``dtype`` was ``Int64`` (:issue:`43016`)
412+
- Bug in :meth:`DataFrame.rolling.corr` when the :class:`DataFrame` columns was a :class:`MultiIndex` (:issue:`21157`)
412413
- Bug in :meth:`DataFrame.groupby.rolling` when specifying ``on`` and calling ``__getitem__`` would subsequently return incorrect results (:issue:`43355`)
413414

414415
Reshaping

pandas/core/window/common.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,24 @@ def dataframe_from_int_dict(data, frame_template):
8383
# mypy needs to know columns is a MultiIndex, Index doesn't
8484
# have levels attribute
8585
arg2.columns = cast(MultiIndex, arg2.columns)
86-
result.index = MultiIndex.from_product(
87-
arg2.columns.levels + [result_index]
86+
# GH 21157: Equivalent to MultiIndex.from_product(
87+
# [result_index], <unique combinations of arg2.columns.levels>,
88+
# )
89+
# A normal MultiIndex.from_product will produce too many
90+
# combinations.
91+
result_level = np.tile(
92+
result_index, len(result) // len(result_index)
93+
)
94+
arg2_levels = (
95+
np.repeat(
96+
arg2.columns.get_level_values(i),
97+
len(result) // len(arg2.columns),
98+
)
99+
for i in range(arg2.columns.nlevels)
100+
)
101+
result_names = list(arg2.columns.names) + [result_index.name]
102+
result.index = MultiIndex.from_arrays(
103+
[*arg2_levels, result_level], names=result_names
88104
)
89105
# GH 34440
90106
num_levels = len(result.index.levels)

pandas/tests/window/test_pairwise.py

+15
Original file line numberDiff line numberDiff line change
@@ -222,3 +222,18 @@ def test_cov_mulittindex(self):
222222
)
223223

224224
tm.assert_frame_equal(result, expected)
225+
226+
def test_multindex_columns_pairwise_func(self):
227+
# GH 21157
228+
columns = MultiIndex.from_arrays([["M", "N"], ["P", "Q"]], names=["a", "b"])
229+
df = DataFrame(np.ones((5, 2)), columns=columns)
230+
result = df.rolling(3).corr()
231+
expected = DataFrame(
232+
np.nan,
233+
index=MultiIndex.from_arrays(
234+
[np.repeat(np.arange(5), 2), ["M", "N"] * 5, ["P", "Q"] * 5],
235+
names=[None, "a", "b"],
236+
),
237+
columns=columns,
238+
)
239+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)