Skip to content

Commit 63651f3

Browse files
authored
BUG: support corr and cov functions for custom BaseIndexer rolling windows (#33804)
1 parent b828d74 commit 63651f3

File tree

4 files changed

+58
-52
lines changed

4 files changed

+58
-52
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-2
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,7 @@ Other API changes
224224
- Added :meth:`DataFrame.value_counts` (:issue:`5377`)
225225
- :meth:`Groupby.groups` now returns an abbreviated representation when called on large dataframes (:issue:`1135`)
226226
- ``loc`` lookups with an object-dtype :class:`Index` and an integer key will now raise ``KeyError`` instead of ``TypeError`` when key is missing (:issue:`31905`)
227-
- Using a :func:`pandas.api.indexers.BaseIndexer` with ``cov``, ``corr`` will now raise a ``NotImplementedError`` (:issue:`32865`)
228-
- Using a :func:`pandas.api.indexers.BaseIndexer` with ``count``, ``min``, ``max``, ``median``, ``skew`` will now return correct results for any monotonic :func:`pandas.api.indexers.BaseIndexer` descendant (:issue:`32865`)
227+
- Using a :func:`pandas.api.indexers.BaseIndexer` with ``count``, ``min``, ``max``, ``median``, ``skew``, ``cov``, ``corr`` will now return correct results for any monotonic :func:`pandas.api.indexers.BaseIndexer` descendant (:issue:`32865`)
229228
- Added a :func:`pandas.api.indexers.FixedForwardWindowIndexer` class to support forward-looking windows during ``rolling`` operations.
230229
-
231230

pandas/core/window/common.py

-22
Original file line numberDiff line numberDiff line change
@@ -324,25 +324,3 @@ def func(arg, window, min_periods=None):
324324
return cfunc(arg, window, min_periods)
325325

326326
return func
327-
328-
329-
def validate_baseindexer_support(func_name: Optional[str]) -> None:
330-
# GH 32865: These functions work correctly with a BaseIndexer subclass
331-
BASEINDEXER_WHITELIST = {
332-
"count",
333-
"min",
334-
"max",
335-
"mean",
336-
"sum",
337-
"median",
338-
"std",
339-
"var",
340-
"skew",
341-
"kurt",
342-
"quantile",
343-
}
344-
if isinstance(func_name, str) and func_name not in BASEINDEXER_WHITELIST:
345-
raise NotImplementedError(
346-
f"{func_name} is not supported with using a BaseIndexer "
347-
f"subclasses. You can use .apply() with {func_name}."
348-
)

pandas/core/window/rolling.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
calculate_center_offset,
4949
calculate_min_periods,
5050
get_weighted_roll_func,
51-
validate_baseindexer_support,
5251
zsqrt,
5352
)
5453
from pandas.core.window.indexers import (
@@ -393,12 +392,11 @@ def _get_cython_func_type(self, func: str) -> Callable:
393392
return self._get_roll_func(f"{func}_variable")
394393
return partial(self._get_roll_func(f"{func}_fixed"), win=self._get_window())
395394

396-
def _get_window_indexer(self, window: int, func_name: Optional[str]) -> BaseIndexer:
395+
def _get_window_indexer(self, window: int) -> BaseIndexer:
397396
"""
398397
Return an indexer class that will compute the window start and end bounds
399398
"""
400399
if isinstance(self.window, BaseIndexer):
401-
validate_baseindexer_support(func_name)
402400
return self.window
403401
if self.is_freq_type:
404402
return VariableWindowIndexer(index_array=self._on.asi8, window_size=window)
@@ -444,7 +442,7 @@ def _apply(
444442

445443
blocks, obj = self._create_blocks()
446444
block_list = list(blocks)
447-
window_indexer = self._get_window_indexer(window, name)
445+
window_indexer = self._get_window_indexer(window)
448446

449447
results = []
450448
exclude: List[Scalar] = []
@@ -1632,20 +1630,23 @@ def quantile(self, quantile, interpolation="linear", **kwargs):
16321630
"""
16331631

16341632
def cov(self, other=None, pairwise=None, ddof=1, **kwargs):
1635-
if isinstance(self.window, BaseIndexer):
1636-
validate_baseindexer_support("cov")
1637-
16381633
if other is None:
16391634
other = self._selected_obj
16401635
# only default unset
16411636
pairwise = True if pairwise is None else pairwise
16421637
other = self._shallow_copy(other)
16431638

1644-
# GH 16058: offset window
1645-
if self.is_freq_type:
1646-
window = self.win_freq
1639+
# GH 32865. We leverage rolling.mean, so we pass
1640+
# to the rolling constructors the data used when constructing self:
1641+
# window width, frequency data, or a BaseIndexer subclass
1642+
if isinstance(self.window, BaseIndexer):
1643+
window = self.window
16471644
else:
1648-
window = self._get_window(other)
1645+
# GH 16058: offset window
1646+
if self.is_freq_type:
1647+
window = self.win_freq
1648+
else:
1649+
window = self._get_window(other)
16491650

16501651
def _get_cov(X, Y):
16511652
# GH #12373 : rolling functions error on float32 data
@@ -1778,15 +1779,19 @@ def _get_cov(X, Y):
17781779
)
17791780

17801781
def corr(self, other=None, pairwise=None, **kwargs):
1781-
if isinstance(self.window, BaseIndexer):
1782-
validate_baseindexer_support("corr")
1783-
17841782
if other is None:
17851783
other = self._selected_obj
17861784
# only default unset
17871785
pairwise = True if pairwise is None else pairwise
17881786
other = self._shallow_copy(other)
1789-
window = self._get_window(other) if not self.is_freq_type else self.win_freq
1787+
1788+
# GH 32865. We leverage rolling.cov and rolling.std here, so we pass
1789+
# to the rolling constructors the data used when constructing self:
1790+
# window width, frequency data, or a BaseIndexer subclass
1791+
if isinstance(self.window, BaseIndexer):
1792+
window = self.window
1793+
else:
1794+
window = self._get_window(other) if not self.is_freq_type else self.win_freq
17901795

17911796
def _get_corr(a, b):
17921797
a = a.rolling(

pandas/tests/window/test_base_indexer.py

+37-13
Original file line numberDiff line numberDiff line change
@@ -82,19 +82,6 @@ def get_window_bounds(self, num_values, min_periods, center, closed):
8282
df.rolling(indexer, win_type="boxcar")
8383

8484

85-
@pytest.mark.parametrize("func", ["cov", "corr"])
86-
def test_notimplemented_functions(func):
87-
# GH 32865
88-
class CustomIndexer(BaseIndexer):
89-
def get_window_bounds(self, num_values, min_periods, center, closed):
90-
return np.array([0, 1]), np.array([1, 2])
91-
92-
df = DataFrame({"values": range(2)})
93-
indexer = CustomIndexer()
94-
with pytest.raises(NotImplementedError, match=f"{func} is not supported"):
95-
getattr(df.rolling(indexer), func)()
96-
97-
9885
@pytest.mark.parametrize("constructor", [Series, DataFrame])
9986
@pytest.mark.parametrize(
10087
"func,np_func,expected,np_kwargs",
@@ -210,3 +197,40 @@ def test_rolling_forward_skewness(constructor):
210197
]
211198
)
212199
tm.assert_equal(result, expected)
200+
201+
202+
@pytest.mark.parametrize(
203+
"func,expected",
204+
[
205+
("cov", [2.0, 2.0, 2.0, 97.0, 2.0, -93.0, 2.0, 2.0, np.nan, np.nan],),
206+
(
207+
"corr",
208+
[
209+
1.0,
210+
1.0,
211+
1.0,
212+
0.8704775290207161,
213+
0.018229084250926637,
214+
-0.861357304646493,
215+
1.0,
216+
1.0,
217+
np.nan,
218+
np.nan,
219+
],
220+
),
221+
],
222+
)
223+
def test_rolling_forward_cov_corr(func, expected):
224+
values1 = np.arange(10).reshape(-1, 1)
225+
values2 = values1 * 2
226+
values1[5, 0] = 100
227+
values = np.concatenate([values1, values2], axis=1)
228+
229+
indexer = FixedForwardWindowIndexer(window_size=3)
230+
rolling = DataFrame(values).rolling(window=indexer, min_periods=3)
231+
# We are interested in checking only pairwise covariance / correlation
232+
result = getattr(rolling, func)().loc[(slice(None), 1), 0]
233+
result = result.reset_index(drop=True)
234+
expected = Series(expected)
235+
expected.name = result.name
236+
tm.assert_equal(result, expected)

0 commit comments

Comments
 (0)