diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index a85b84aad9f94..5ccf3015ac257 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -408,6 +408,7 @@ Groupby/resample/rolling - Bug in :meth:`pandas.DataFrame.rolling` operation along rows (``axis=1``) incorrectly omits columns containing ``float16`` and ``float32`` (:issue:`41779`) - Bug in :meth:`Resampler.aggregate` did not allow the use of Named Aggregation (:issue:`32803`) - Bug in :meth:`Series.rolling` when the :class:`Series` ``dtype`` was ``Int64`` (:issue:`43016`) +- Bug in :meth:`DataFrame.rolling.corr` when the :class:`DataFrame` columns was a :class:`MultiIndex` (:issue:`21157`) - Bug in :meth:`DataFrame.groupby.rolling` when specifying ``on`` and calling ``__getitem__`` would subsequently return incorrect results (:issue:`43355`) Reshaping diff --git a/pandas/core/window/common.py b/pandas/core/window/common.py index e0720c5d86df1..15144116fa924 100644 --- a/pandas/core/window/common.py +++ b/pandas/core/window/common.py @@ -83,8 +83,24 @@ def dataframe_from_int_dict(data, frame_template): # mypy needs to know columns is a MultiIndex, Index doesn't # have levels attribute arg2.columns = cast(MultiIndex, arg2.columns) - result.index = MultiIndex.from_product( - arg2.columns.levels + [result_index] + # GH 21157: Equivalent to MultiIndex.from_product( + # [result_index], , + # ) + # A normal MultiIndex.from_product will produce too many + # combinations. + result_level = np.tile( + result_index, len(result) // len(result_index) + ) + arg2_levels = ( + np.repeat( + arg2.columns.get_level_values(i), + len(result) // len(arg2.columns), + ) + for i in range(arg2.columns.nlevels) + ) + result_names = list(arg2.columns.names) + [result_index.name] + result.index = MultiIndex.from_arrays( + [*arg2_levels, result_level], names=result_names ) # GH 34440 num_levels = len(result.index.levels) diff --git a/pandas/tests/window/test_pairwise.py b/pandas/tests/window/test_pairwise.py index a0d24a061fc4a..f43d7ec99e312 100644 --- a/pandas/tests/window/test_pairwise.py +++ b/pandas/tests/window/test_pairwise.py @@ -222,3 +222,18 @@ def test_cov_mulittindex(self): ) tm.assert_frame_equal(result, expected) + + def test_multindex_columns_pairwise_func(self): + # GH 21157 + columns = MultiIndex.from_arrays([["M", "N"], ["P", "Q"]], names=["a", "b"]) + df = DataFrame(np.ones((5, 2)), columns=columns) + result = df.rolling(3).corr() + expected = DataFrame( + np.nan, + index=MultiIndex.from_arrays( + [np.repeat(np.arange(5), 2), ["M", "N"] * 5, ["P", "Q"] * 5], + names=[None, "a", "b"], + ), + columns=columns, + ) + tm.assert_frame_equal(result, expected)