Skip to content

Commit b471c29

Browse files
authored
PERF: Rolling.cov/corr (#39388)
1 parent 88bea62 commit b471c29

File tree

3 files changed

+69
-81
lines changed

3 files changed

+69
-81
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ Performance improvements
207207
- Performance improvement in :meth:`Series.mean` for nullable data types (:issue:`34814`)
208208
- Performance improvement in :meth:`Series.isin` for nullable data types (:issue:`38340`)
209209
- Performance improvement in :meth:`DataFrame.corr` for method=kendall (:issue:`28329`)
210+
- Performance improvement in :meth:`core.window.Rolling.corr` and :meth:`core.window.Rolling.cov` (:issue:`39388`)
210211

211212
.. ---------------------------------------------------------------------------
212213

pandas/core/window/expanding.py

+2-24
Original file line numberDiff line numberDiff line change
@@ -82,29 +82,6 @@ def _get_window_indexer(self) -> BaseIndexer:
8282
"""
8383
return ExpandingIndexer()
8484

85-
def _get_cov_corr_window(
86-
self, other: Optional[Union[np.ndarray, FrameOrSeries]] = None, **kwargs
87-
) -> int:
88-
"""
89-
Get the window length over which to perform cov and corr operations.
90-
91-
Parameters
92-
----------
93-
other : object, default None
94-
The other object that is involved in the operation.
95-
Such an object is involved for operations like covariance.
96-
97-
Returns
98-
-------
99-
window : int
100-
The window length.
101-
"""
102-
axis = self.obj._get_axis(self.axis)
103-
length = len(axis) + (other is not None) * len(axis)
104-
105-
other = self.min_periods or -1
106-
return max(length, other)
107-
10885
_agg_see_also_doc = dedent(
10986
"""
11087
See Also
@@ -286,9 +263,10 @@ def corr(
286263
self,
287264
other: Optional[Union[np.ndarray, FrameOrSeries]] = None,
288265
pairwise: Optional[bool] = None,
266+
ddof: int = 1,
289267
**kwargs,
290268
):
291-
return super().corr(other=other, pairwise=pairwise, **kwargs)
269+
return super().corr(other=other, pairwise=pairwise, ddof=ddof, **kwargs)
292270

293271

294272
class ExpandingGroupby(BaseWindowGroupby, Expanding):

pandas/core/window/rolling.py

+66-57
Original file line numberDiff line numberDiff line change
@@ -245,23 +245,6 @@ def __getattr__(self, attr: str):
245245
def _dir_additions(self):
246246
return self.obj._dir_additions()
247247

248-
def _get_cov_corr_window(
249-
self, other: Optional[Union[np.ndarray, FrameOrSeries]] = None
250-
) -> Optional[Union[int, timedelta, BaseOffset, BaseIndexer]]:
251-
"""
252-
Return window length.
253-
254-
Parameters
255-
----------
256-
other :
257-
Used in Expanding
258-
259-
Returns
260-
-------
261-
window : int
262-
"""
263-
return self.window
264-
265248
def __repr__(self) -> str:
266249
"""
267250
Provide a nice str repr of our rolling object.
@@ -1853,32 +1836,38 @@ def cov(self, other=None, pairwise=None, ddof=1, **kwargs):
18531836
other = self._selected_obj
18541837
# only default unset
18551838
pairwise = True if pairwise is None else pairwise
1856-
other = self._shallow_copy(other)
18571839

1858-
# GH 32865. We leverage rolling.mean, so we pass
1859-
# to the rolling constructors the data used when constructing self:
1860-
# window width, frequency data, or a BaseIndexer subclass
1861-
# GH 16058: offset window
1862-
window = self._get_cov_corr_window(other)
1840+
from pandas import Series
18631841

1864-
def _get_cov(X, Y):
1865-
# GH #12373 : rolling functions error on float32 data
1866-
# to avoid potential overflow, cast the data to float64
1867-
X = X.astype("float64")
1868-
Y = Y.astype("float64")
1869-
mean = lambda x: x.rolling(
1870-
window, self.min_periods, center=self.center
1871-
).mean(**kwargs)
1872-
count = (
1873-
(X + Y)
1874-
.rolling(window=window, min_periods=0, center=self.center)
1875-
.count(**kwargs)
1842+
def cov_func(x, y):
1843+
x_array = self._prep_values(x)
1844+
y_array = self._prep_values(y)
1845+
window_indexer = self._get_window_indexer()
1846+
min_periods = (
1847+
self.min_periods
1848+
if self.min_periods is not None
1849+
else window_indexer.window_size
18761850
)
1877-
bias_adj = count / (count - ddof)
1878-
return (mean(X * Y) - mean(X) * mean(Y)) * bias_adj
1851+
start, end = window_indexer.get_window_bounds(
1852+
num_values=len(x_array),
1853+
min_periods=min_periods,
1854+
center=self.center,
1855+
closed=self.closed,
1856+
)
1857+
with np.errstate(all="ignore"):
1858+
mean_x_y = window_aggregations.roll_mean(
1859+
x_array * y_array, start, end, min_periods
1860+
)
1861+
mean_x = window_aggregations.roll_mean(x_array, start, end, min_periods)
1862+
mean_y = window_aggregations.roll_mean(y_array, start, end, min_periods)
1863+
count_x_y = window_aggregations.roll_sum(
1864+
notna(x_array + y_array).astype(np.float64), start, end, 0
1865+
)
1866+
result = (mean_x_y - mean_x * mean_y) * (count_x_y / (count_x_y - ddof))
1867+
return Series(result, index=x.index, name=x.name)
18791868

18801869
return flex_binary_moment(
1881-
self._selected_obj, other._selected_obj, _get_cov, pairwise=bool(pairwise)
1870+
self._selected_obj, other, cov_func, pairwise=bool(pairwise)
18821871
)
18831872

18841873
_shared_docs["corr"] = dedent(
@@ -1991,33 +1980,53 @@ def _get_cov(X, Y):
19911980
"""
19921981
)
19931982

1994-
def corr(self, other=None, pairwise=None, **kwargs):
1983+
def corr(self, other=None, pairwise=None, ddof=1, **kwargs):
19951984
if other is None:
19961985
other = self._selected_obj
19971986
# only default unset
19981987
pairwise = True if pairwise is None else pairwise
1999-
other = self._shallow_copy(other)
20001988

2001-
# GH 32865. We leverage rolling.cov and rolling.std here, so we pass
2002-
# to the rolling constructors the data used when constructing self:
2003-
# window width, frequency data, or a BaseIndexer subclass
2004-
# GH 16058: offset window
2005-
window = self._get_cov_corr_window(other)
1989+
from pandas import Series
20061990

2007-
def _get_corr(a, b):
2008-
a = a.rolling(
2009-
window=window, min_periods=self.min_periods, center=self.center
1991+
def corr_func(x, y):
1992+
x_array = self._prep_values(x)
1993+
y_array = self._prep_values(y)
1994+
window_indexer = self._get_window_indexer()
1995+
min_periods = (
1996+
self.min_periods
1997+
if self.min_periods is not None
1998+
else window_indexer.window_size
20101999
)
2011-
b = b.rolling(
2012-
window=window, min_periods=self.min_periods, center=self.center
2000+
start, end = window_indexer.get_window_bounds(
2001+
num_values=len(x_array),
2002+
min_periods=min_periods,
2003+
center=self.center,
2004+
closed=self.closed,
20132005
)
2014-
# GH 31286: Through using var instead of std we can avoid numerical
2015-
# issues when the result of var is within floating proint precision
2016-
# while std is not.
2017-
return a.cov(b, **kwargs) / (a.var(**kwargs) * b.var(**kwargs)) ** 0.5
2006+
with np.errstate(all="ignore"):
2007+
mean_x_y = window_aggregations.roll_mean(
2008+
x_array * y_array, start, end, min_periods
2009+
)
2010+
mean_x = window_aggregations.roll_mean(x_array, start, end, min_periods)
2011+
mean_y = window_aggregations.roll_mean(y_array, start, end, min_periods)
2012+
count_x_y = window_aggregations.roll_sum(
2013+
notna(x_array + y_array).astype(np.float64), start, end, 0
2014+
)
2015+
x_var = window_aggregations.roll_var(
2016+
x_array, start, end, min_periods, ddof
2017+
)
2018+
y_var = window_aggregations.roll_var(
2019+
y_array, start, end, min_periods, ddof
2020+
)
2021+
numerator = (mean_x_y - mean_x * mean_y) * (
2022+
count_x_y / (count_x_y - ddof)
2023+
)
2024+
denominator = (x_var * y_var) ** 0.5
2025+
result = numerator / denominator
2026+
return Series(result, index=x.index, name=x.name)
20182027

20192028
return flex_binary_moment(
2020-
self._selected_obj, other._selected_obj, _get_corr, pairwise=bool(pairwise)
2029+
self._selected_obj, other, corr_func, pairwise=bool(pairwise)
20212030
)
20222031

20232032

@@ -2254,8 +2263,8 @@ def cov(self, other=None, pairwise=None, ddof=1, **kwargs):
22542263

22552264
@Substitution(name="rolling")
22562265
@Appender(_shared_docs["corr"])
2257-
def corr(self, other=None, pairwise=None, **kwargs):
2258-
return super().corr(other=other, pairwise=pairwise, **kwargs)
2266+
def corr(self, other=None, pairwise=None, ddof=1, **kwargs):
2267+
return super().corr(other=other, pairwise=pairwise, ddof=ddof, **kwargs)
22592268

22602269

22612270
Rolling.__doc__ = Window.__doc__

0 commit comments

Comments
 (0)