Skip to content

Commit 863a6a4

Browse files
mroeschkejreback
authored andcommitted
CLN: Simplify rolling.py helper functions (#30672)
1 parent 9da81ac commit 863a6a4

File tree

3 files changed

+43
-92
lines changed

3 files changed

+43
-92
lines changed

pandas/core/window/common.py

+16-28
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _flex_binary_moment(arg1, arg2, f, pairwise=False):
105105
if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
106106
arg2, (np.ndarray, ABCSeries)
107107
):
108-
X, Y = _prep_binary(arg1, arg2)
108+
X, Y = prep_binary(arg1, arg2)
109109
return f(X, Y)
110110

111111
elif isinstance(arg1, ABCDataFrame):
@@ -152,7 +152,7 @@ def dataframe_from_int_dict(data, frame_template):
152152
results[i][j] = results[j][i]
153153
else:
154154
results[i][j] = f(
155-
*_prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
155+
*prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
156156
)
157157

158158
from pandas import concat
@@ -213,7 +213,7 @@ def dataframe_from_int_dict(data, frame_template):
213213
raise ValueError("'pairwise' is not True/False")
214214
else:
215215
results = {
216-
i: f(*_prep_binary(arg1.iloc[:, i], arg2))
216+
i: f(*prep_binary(arg1.iloc[:, i], arg2))
217217
for i, col in enumerate(arg1.columns)
218218
}
219219
return dataframe_from_int_dict(results, arg1)
@@ -250,31 +250,10 @@ def _get_center_of_mass(comass, span, halflife, alpha):
250250
return float(comass)
251251

252252

253-
def _offset(window, center):
253+
def calculate_center_offset(window):
254254
if not is_integer(window):
255255
window = len(window)
256-
offset = (window - 1) / 2.0 if center else 0
257-
try:
258-
return int(offset)
259-
except TypeError:
260-
return offset.astype(int)
261-
262-
263-
def _require_min_periods(p):
264-
def _check_func(minp, window):
265-
if minp is None:
266-
return window
267-
else:
268-
return max(p, minp)
269-
270-
return _check_func
271-
272-
273-
def _use_window(minp, window):
274-
if minp is None:
275-
return window
276-
else:
277-
return minp
256+
return int((window - 1) / 2.0)
278257

279258

280259
def calculate_min_periods(
@@ -312,7 +291,7 @@ def calculate_min_periods(
312291
return max(min_periods, floor)
313292

314293

315-
def _zsqrt(x):
294+
def zsqrt(x):
316295
with np.errstate(all="ignore"):
317296
result = np.sqrt(x)
318297
mask = x < 0
@@ -327,7 +306,7 @@ def _zsqrt(x):
327306
return result
328307

329308

330-
def _prep_binary(arg1, arg2):
309+
def prep_binary(arg1, arg2):
331310
if not isinstance(arg2, type(arg1)):
332311
raise Exception("Input arrays must be of the same type!")
333312

@@ -336,3 +315,12 @@ def _prep_binary(arg1, arg2):
336315
Y = arg2 + 0 * arg1
337316

338317
return X, Y
318+
319+
320+
def get_weighted_roll_func(cfunc: Callable) -> Callable:
321+
def func(arg, window, min_periods=None):
322+
if min_periods is None:
323+
min_periods = len(window)
324+
return cfunc(arg, window, min_periods)
325+
326+
return func

pandas/core/window/ewm.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,13 @@
99
from pandas.core.dtypes.generic import ABCDataFrame
1010

1111
from pandas.core.base import DataError
12-
from pandas.core.window.common import _doc_template, _get_center_of_mass, _shared_docs
13-
from pandas.core.window.rolling import _flex_binary_moment, _Rolling, _zsqrt
12+
from pandas.core.window.common import (
13+
_doc_template,
14+
_get_center_of_mass,
15+
_shared_docs,
16+
zsqrt,
17+
)
18+
from pandas.core.window.rolling import _flex_binary_moment, _Rolling
1419

1520
_bias_template = """
1621
Parameters
@@ -269,7 +274,7 @@ def std(self, bias=False, *args, **kwargs):
269274
Exponential weighted moving stddev.
270275
"""
271276
nv.validate_window_func("std", args, kwargs)
272-
return _zsqrt(self.var(bias=bias, **kwargs))
277+
return zsqrt(self.var(bias=bias, **kwargs))
273278

274279
vol = std
275280

@@ -390,7 +395,7 @@ def _cov(x, y):
390395
cov = _cov(x_values, y_values)
391396
x_var = _cov(x_values, x_values)
392397
y_var = _cov(y_values, y_values)
393-
corr = cov / _zsqrt(x_var * y_var)
398+
corr = cov / zsqrt(x_var * y_var)
394399
return X._wrap_result(corr)
395400

396401
return _flex_binary_moment(

pandas/core/window/rolling.py

+18-60
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
is_integer_dtype,
2525
is_list_like,
2626
is_scalar,
27-
is_timedelta64_dtype,
2827
needs_i8_conversion,
2928
)
3029
from pandas.core.dtypes.generic import (
@@ -43,11 +42,11 @@
4342
WindowGroupByMixin,
4443
_doc_template,
4544
_flex_binary_moment,
46-
_offset,
4745
_shared_docs,
48-
_use_window,
49-
_zsqrt,
46+
calculate_center_offset,
5047
calculate_min_periods,
48+
get_weighted_roll_func,
49+
zsqrt,
5150
)
5251
from pandas.core.window.indexers import (
5352
BaseIndexer,
@@ -252,19 +251,6 @@ def __iter__(self):
252251
url = "https://github.com/pandas-dev/pandas/issues/11704"
253252
raise NotImplementedError(f"See issue #11704 {url}")
254253

255-
def _get_index(self) -> Optional[np.ndarray]:
256-
"""
257-
Return integer representations as an ndarray if index is frequency.
258-
259-
Returns
260-
-------
261-
None or ndarray
262-
"""
263-
264-
if self.is_freq_type:
265-
return self._on.asi8
266-
return None
267-
268254
def _prep_values(self, values: Optional[np.ndarray] = None) -> np.ndarray:
269255
"""Convert input to numpy arrays for Cython routines"""
270256
if values is None:
@@ -305,17 +291,6 @@ def _wrap_result(self, result, block=None, obj=None):
305291

306292
if isinstance(result, np.ndarray):
307293

308-
# coerce if necessary
309-
if block is not None:
310-
if is_timedelta64_dtype(block.values.dtype):
311-
# TODO: do we know what result.dtype is at this point?
312-
# i.e. can we just do an astype?
313-
from pandas import to_timedelta
314-
315-
result = to_timedelta(result.ravel(), unit="ns").values.reshape(
316-
result.shape
317-
)
318-
319294
if result.ndim == 1:
320295
from pandas import Series
321296

@@ -384,14 +359,11 @@ def _center_window(self, result, window) -> np.ndarray:
384359
if self.axis > result.ndim - 1:
385360
raise ValueError("Requested axis is larger then no. of argument dimensions")
386361

387-
offset = _offset(window, True)
362+
offset = calculate_center_offset(window)
388363
if offset > 0:
389-
if isinstance(result, (ABCSeries, ABCDataFrame)):
390-
result = result.slice_shift(-offset, axis=self.axis)
391-
else:
392-
lead_indexer = [slice(None)] * result.ndim
393-
lead_indexer[self.axis] = slice(offset, None)
394-
result = np.copy(result[tuple(lead_indexer)])
364+
lead_indexer = [slice(None)] * result.ndim
365+
lead_indexer[self.axis] = slice(offset, None)
366+
result = np.copy(result[tuple(lead_indexer)])
395367
return result
396368

397369
def _get_roll_func(self, func_name: str) -> Callable:
@@ -424,17 +396,15 @@ def _get_cython_func_type(self, func: str) -> Callable:
424396
return self._get_roll_func(f"{func}_variable")
425397
return partial(self._get_roll_func(f"{func}_fixed"), win=self._get_window())
426398

427-
def _get_window_indexer(
428-
self, index_as_array: Optional[np.ndarray], window: int
429-
) -> BaseIndexer:
399+
def _get_window_indexer(self, window: int) -> BaseIndexer:
430400
"""
431401
Return an indexer class that will compute the window start and end bounds
432402
"""
433403
if isinstance(self.window, BaseIndexer):
434404
return self.window
435405
if self.is_freq_type:
436-
return VariableWindowIndexer(index_array=index_as_array, window_size=window)
437-
return FixedWindowIndexer(index_array=index_as_array, window_size=window)
406+
return VariableWindowIndexer(index_array=self._on.asi8, window_size=window)
407+
return FixedWindowIndexer(window_size=window)
438408

439409
def _apply(
440410
self,
@@ -476,8 +446,7 @@ def _apply(
476446

477447
blocks, obj = self._create_blocks()
478448
block_list = list(blocks)
479-
index_as_array = self._get_index()
480-
window_indexer = self._get_window_indexer(index_as_array, window)
449+
window_indexer = self._get_window_indexer(window)
481450

482451
results = []
483452
exclude: List[Scalar] = []
@@ -498,7 +467,7 @@ def _apply(
498467
continue
499468

500469
# calculation function
501-
offset = _offset(window, center) if center else 0
470+
offset = calculate_center_offset(window) if center else 0
502471
additional_nans = np.array([np.nan] * offset)
503472

504473
if not is_weighted:
@@ -1051,15 +1020,6 @@ def _get_window(
10511020
# GH #15662. `False` makes symmetric window, rather than periodic.
10521021
return sig.get_window(win_type, window, False).astype(float)
10531022

1054-
def _get_weighted_roll_func(
1055-
self, cfunc: Callable, check_minp: Callable, **kwargs
1056-
) -> Callable:
1057-
def func(arg, window, min_periods=None, closed=None):
1058-
minp = check_minp(min_periods, len(window))
1059-
return cfunc(arg, window, minp, **kwargs)
1060-
1061-
return func
1062-
10631023
_agg_see_also_doc = dedent(
10641024
"""
10651025
See Also
@@ -1127,7 +1087,7 @@ def aggregate(self, func, *args, **kwargs):
11271087
def sum(self, *args, **kwargs):
11281088
nv.validate_window_func("sum", args, kwargs)
11291089
window_func = self._get_roll_func("roll_weighted_sum")
1130-
window_func = self._get_weighted_roll_func(window_func, _use_window)
1090+
window_func = get_weighted_roll_func(window_func)
11311091
return self._apply(
11321092
window_func, center=self.center, is_weighted=True, name="sum", **kwargs
11331093
)
@@ -1137,7 +1097,7 @@ def sum(self, *args, **kwargs):
11371097
def mean(self, *args, **kwargs):
11381098
nv.validate_window_func("mean", args, kwargs)
11391099
window_func = self._get_roll_func("roll_weighted_mean")
1140-
window_func = self._get_weighted_roll_func(window_func, _use_window)
1100+
window_func = get_weighted_roll_func(window_func)
11411101
return self._apply(
11421102
window_func, center=self.center, is_weighted=True, name="mean", **kwargs
11431103
)
@@ -1147,7 +1107,7 @@ def mean(self, *args, **kwargs):
11471107
def var(self, ddof=1, *args, **kwargs):
11481108
nv.validate_window_func("var", args, kwargs)
11491109
window_func = partial(self._get_roll_func("roll_weighted_var"), ddof=ddof)
1150-
window_func = self._get_weighted_roll_func(window_func, _use_window)
1110+
window_func = get_weighted_roll_func(window_func)
11511111
kwargs.pop("name", None)
11521112
return self._apply(
11531113
window_func, center=self.center, is_weighted=True, name="var", **kwargs
@@ -1157,7 +1117,7 @@ def var(self, ddof=1, *args, **kwargs):
11571117
@Appender(_shared_docs["std"])
11581118
def std(self, ddof=1, *args, **kwargs):
11591119
nv.validate_window_func("std", args, kwargs)
1160-
return _zsqrt(self.var(ddof=ddof, name="std", **kwargs))
1120+
return zsqrt(self.var(ddof=ddof, name="std", **kwargs))
11611121

11621122

11631123
class _Rolling(_Window):
@@ -1211,8 +1171,6 @@ class _Rolling_and_Expanding(_Rolling):
12111171
def count(self):
12121172

12131173
blocks, obj = self._create_blocks()
1214-
# Validate the index
1215-
self._get_index()
12161174

12171175
window = self._get_window()
12181176
window = min(window, len(obj)) if not self.center else window
@@ -1307,7 +1265,7 @@ def apply(
13071265
kwargs.pop("_level", None)
13081266
kwargs.pop("floor", None)
13091267
window = self._get_window()
1310-
offset = _offset(window, self.center)
1268+
offset = calculate_center_offset(window) if self.center else 0
13111269
if not is_bool(raw):
13121270
raise ValueError("raw parameter must be `True` or `False`")
13131271

@@ -1478,7 +1436,7 @@ def std(self, ddof=1, *args, **kwargs):
14781436
window_func = self._get_cython_func_type("roll_var")
14791437

14801438
def zsqrt_func(values, begin, end, min_periods):
1481-
return _zsqrt(window_func(values, begin, end, min_periods, ddof=ddof))
1439+
return zsqrt(window_func(values, begin, end, min_periods, ddof=ddof))
14821440

14831441
# ddof passed again for compat with groupby.rolling
14841442
return self._apply(

0 commit comments

Comments
 (0)