Commit 6a9ceac

BUG: exponential moving window covariance fails for multiIndexed DataFrame (#34943)
* added test for df.ewm.cov with multiindex
* BUG: fixed _flex_binary_moment for multiindex
* added reference to GH issue
* DOC: updated whatnew
* DOC: moved note to rolling section of whatsnew
* changed df to fixed seed, linearly spaced ints
* removed extraneous comment
* TST: hardcoded expected df for test_multiindex_cov
* TST: cleaned up comment + blank line
* TST: clean up index definition
1 parent e37ff6e commit 6a9ceac

3 files changed: +31 -2 lines changed

doc/source/whatsnew/v1.1.0.rst

+1
@@ -1061,6 +1061,7 @@ Groupby/resample/rolling
 - Bug in :meth:`SeriesGroupBy.agg` where any column name was accepted in the named aggregation of ``SeriesGroupBy`` previously. The behaviour now allows only ``str`` and callables else would raise ``TypeError``. (:issue:`34422`)
 - Bug in :meth:`DataFrame.groupby` lost index, when one of the ``agg`` keys referenced an empty list (:issue:`32580`)
 - Bug in :meth:`Rolling.apply` where ``center=True`` was ignored when ``engine='numba'`` was specified (:issue:`34784`)
+- Bug in :meth:`DataFrame.ewm.cov` was throwing ``AssertionError`` for :class:`MultiIndex` inputs (:issue:`34440`)
 
 Reshaping
 ^^^^^^^^^

pandas/core/window/common.py

+4 -1
@@ -179,7 +179,10 @@ def dataframe_from_int_dict(data, frame_template):
                     result.index = MultiIndex.from_product(
                         arg2.columns.levels + [result_index]
                     )
-                    result = result.reorder_levels([2, 0, 1]).sort_index()
+                    # GH 34440
+                    num_levels = len(result.index.levels)
+                    new_order = [num_levels - 1] + list(range(num_levels - 1))
+                    result = result.reorder_levels(new_order).sort_index()
                 else:
                     result.index = MultiIndex.from_product(
                         [range(len(arg2.columns)), range(len(result_index))]
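
The change above replaces the hardcoded order `[2, 0, 1]`, which only fits a three-level result index, with an order computed from the actual number of levels. The following is a minimal pure-Python sketch of that computation, not the pandas implementation; the helper names `last_level_first` and `reorder` are hypothetical and exist only for illustration.

```python
def last_level_first(num_levels):
    """Return a level order that moves the last index level to the front.

    Mirrors the `new_order` expression in the diff; the function name is
    hypothetical and not part of pandas.
    """
    return [num_levels - 1] + list(range(num_levels - 1))


def reorder(key, order):
    """Rearrange one index key (a tuple of level values) into `order`."""
    return tuple(key[i] for i in order)


# Three levels reproduce the previously hardcoded order.
three_level_order = last_level_first(3)

# Four levels: three column levels plus the row label, as in the new test.
four_level_order = last_level_first(4)

# The row label (last position) moves to the front of each key.
example_key = reorder(("a", "x", "A", 0), four_level_order)
print(three_level_order, four_level_order, example_key)
```

With three levels the helper returns the old `[2, 0, 1]`, so behaviour for the previously supported case should be unchanged, while deeper column indexes get a valid order of matching length.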

pandas/tests/window/test_pairwise.py

+26 -1
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-from pandas import DataFrame, Series, date_range
+from pandas import DataFrame, MultiIndex, Series, date_range
 import pandas._testing as tm
 from pandas.core.algorithms import safe_sort
 
@@ -189,3 +189,28 @@ def test_corr_freq_memory_error(self):
         result = s.rolling("12H").corr(s)
         expected = Series([np.nan] * 5, index=date_range("2020", periods=5))
         tm.assert_series_equal(result, expected)
+
+    def test_cov_mulittindex(self):
+        # GH 34440
+
+        columns = MultiIndex.from_product([list("ab"), list("xy"), list("AB")])
+        index = range(3)
+        df = DataFrame(np.arange(24).reshape(3, 8), index=index, columns=columns,)
+
+        result = df.ewm(alpha=0.1).cov()
+
+        index = MultiIndex.from_product([range(3), list("ab"), list("xy"), list("AB")])
+        columns = MultiIndex.from_product([list("ab"), list("xy"), list("AB")])
+        expected = DataFrame(
+            np.vstack(
+                (
+                    np.full((8, 8), np.NaN),
+                    np.full((8, 8), 32.000000),
+                    np.full((8, 8), 63.881919),
+                )
+            ),
+            index=index,
+            columns=columns,
+        )
+
+        tm.assert_frame_equal(result, expected)
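
The hardcoded expected values in the test can be checked by hand. Every column of `np.arange(24).reshape(3, 8)` has the form `c, c + 8, c + 16`, so each pairwise covariance equals the variance of `[0, 8, 16]`, which is why each 8x8 block holds a single value. The sketch below recomputes that variance in plain Python. It assumes the adjusted weights `(1 - alpha) ** i` and the bias correction `W**2 / (W**2 - sum(w**2))`, which is my understanding of the `ewm` defaults rather than something stated in this commit; the function name `ewm_var` is hypothetical.

```python
import math


def ewm_var(values, alpha):
    """Bias-corrected exponentially weighted variance at each step.

    A hand-rolled sketch under the assumptions stated in the text above,
    not the pandas implementation.
    """
    results = []
    for t in range(len(values)):
        window = values[: t + 1]
        # The most recent observation gets weight 1, older ones decay.
        weights = [(1 - alpha) ** (t - i) for i in range(t + 1)]
        w_sum = sum(weights)
        w_sq_sum = sum(w * w for w in weights)
        mean = sum(w * x for w, x in zip(weights, window)) / w_sum
        biased = sum(w * (x - mean) ** 2 for w, x in zip(weights, window)) / w_sum
        denom = w_sum ** 2 - w_sq_sum
        # A single observation leaves the correction undefined, hence NaN.
        results.append(biased * w_sum ** 2 / denom if denom > 0 else math.nan)
    return results


variances = ewm_var([0, 8, 16], alpha=0.1)
# Should be NaN, then close to 32.0 and 63.881919, matching the test blocks.
print(variances)
```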
