Skip to content

CLN: Simplify rolling.py helper functions #30672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions pandas/core/window/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _flex_binary_moment(arg1, arg2, f, pairwise=False):
if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
arg2, (np.ndarray, ABCSeries)
):
X, Y = _prep_binary(arg1, arg2)
X, Y = prep_binary(arg1, arg2)
return f(X, Y)

elif isinstance(arg1, ABCDataFrame):
Expand Down Expand Up @@ -152,7 +152,7 @@ def dataframe_from_int_dict(data, frame_template):
results[i][j] = results[j][i]
else:
results[i][j] = f(
*_prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
*prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
)

from pandas import concat
Expand Down Expand Up @@ -213,7 +213,7 @@ def dataframe_from_int_dict(data, frame_template):
raise ValueError("'pairwise' is not True/False")
else:
results = {
i: f(*_prep_binary(arg1.iloc[:, i], arg2))
i: f(*prep_binary(arg1.iloc[:, i], arg2))
for i, col in enumerate(arg1.columns)
}
return dataframe_from_int_dict(results, arg1)
Expand Down Expand Up @@ -250,31 +250,10 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


def _offset(window, center):
def calculate_center_offset(window):
if not is_integer(window):
window = len(window)
offset = (window - 1) / 2.0 if center else 0
try:
return int(offset)
except TypeError:
return offset.astype(int)


def _require_min_periods(p):
def _check_func(minp, window):
if minp is None:
return window
else:
return max(p, minp)

return _check_func


def _use_window(minp, window):
if minp is None:
return window
else:
return minp
return int((window - 1) / 2.0)


def calculate_min_periods(
Expand Down Expand Up @@ -312,7 +291,7 @@ def calculate_min_periods(
return max(min_periods, floor)


def _zsqrt(x):
def zsqrt(x):
with np.errstate(all="ignore"):
result = np.sqrt(x)
mask = x < 0
Expand All @@ -327,7 +306,7 @@ def _zsqrt(x):
return result


def _prep_binary(arg1, arg2):
def prep_binary(arg1, arg2):
if not isinstance(arg2, type(arg1)):
raise Exception("Input arrays must be of the same type!")

Expand All @@ -336,3 +315,12 @@ def _prep_binary(arg1, arg2):
Y = arg2 + 0 * arg1

return X, Y


def get_weighted_roll_func(cfunc: Callable) -> Callable:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subtypes for Callable would be helpful here if you know from looking at this what they are

def func(arg, window, min_periods=None):
if min_periods is None:
min_periods = len(window)
return cfunc(arg, window, min_periods)

return func
13 changes: 9 additions & 4 deletions pandas/core/window/ewm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@
from pandas.core.dtypes.generic import ABCDataFrame

from pandas.core.base import DataError
from pandas.core.window.common import _doc_template, _get_center_of_mass, _shared_docs
from pandas.core.window.rolling import _flex_binary_moment, _Rolling, _zsqrt
from pandas.core.window.common import (
_doc_template,
_get_center_of_mass,
_shared_docs,
zsqrt,
)
from pandas.core.window.rolling import _flex_binary_moment, _Rolling

_bias_template = """
Parameters
Expand Down Expand Up @@ -269,7 +274,7 @@ def std(self, bias=False, *args, **kwargs):
Exponential weighted moving stddev.
"""
nv.validate_window_func("std", args, kwargs)
return _zsqrt(self.var(bias=bias, **kwargs))
return zsqrt(self.var(bias=bias, **kwargs))

vol = std

Expand Down Expand Up @@ -390,7 +395,7 @@ def _cov(x, y):
cov = _cov(x_values, y_values)
x_var = _cov(x_values, x_values)
y_var = _cov(y_values, y_values)
corr = cov / _zsqrt(x_var * y_var)
corr = cov / zsqrt(x_var * y_var)
return X._wrap_result(corr)

return _flex_binary_moment(
Expand Down
78 changes: 18 additions & 60 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
is_integer_dtype,
is_list_like,
is_scalar,
is_timedelta64_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.generic import (
Expand All @@ -43,11 +42,11 @@
WindowGroupByMixin,
_doc_template,
_flex_binary_moment,
_offset,
_shared_docs,
_use_window,
_zsqrt,
calculate_center_offset,
calculate_min_periods,
get_weighted_roll_func,
zsqrt,
)
from pandas.core.window.indexers import (
BaseIndexer,
Expand Down Expand Up @@ -252,19 +251,6 @@ def __iter__(self):
url = "https://github.com/pandas-dev/pandas/issues/11704"
raise NotImplementedError(f"See issue #11704 {url}")

def _get_index(self) -> Optional[np.ndarray]:
"""
Return integer representations as an ndarray if index is frequency.

Returns
-------
None or ndarray
"""

if self.is_freq_type:
return self._on.asi8
return None

def _prep_values(self, values: Optional[np.ndarray] = None) -> np.ndarray:
"""Convert input to numpy arrays for Cython routines"""
if values is None:
Expand Down Expand Up @@ -305,17 +291,6 @@ def _wrap_result(self, result, block=None, obj=None):

if isinstance(result, np.ndarray):

# coerce if necessary
if block is not None:
if is_timedelta64_dtype(block.values.dtype):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this not reachable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think so. Our test suite includes timedelta64 which doesn't his this branch.

# TODO: do we know what result.dtype is at this point?
# i.e. can we just do an astype?
from pandas import to_timedelta

result = to_timedelta(result.ravel(), unit="ns").values.reshape(
result.shape
)

if result.ndim == 1:
from pandas import Series

Expand Down Expand Up @@ -384,14 +359,11 @@ def _center_window(self, result, window) -> np.ndarray:
if self.axis > result.ndim - 1:
raise ValueError("Requested axis is larger then no. of argument dimensions")

offset = _offset(window, True)
offset = calculate_center_offset(window)
if offset > 0:
if isinstance(result, (ABCSeries, ABCDataFrame)):
result = result.slice_shift(-offset, axis=self.axis)
else:
lead_indexer = [slice(None)] * result.ndim
lead_indexer[self.axis] = slice(offset, None)
result = np.copy(result[tuple(lead_indexer)])
lead_indexer = [slice(None)] * result.ndim
lead_indexer[self.axis] = slice(offset, None)
result = np.copy(result[tuple(lead_indexer)])
return result

def _get_roll_func(self, func_name: str) -> Callable:
Expand Down Expand Up @@ -424,17 +396,15 @@ def _get_cython_func_type(self, func: str) -> Callable:
return self._get_roll_func(f"{func}_variable")
return partial(self._get_roll_func(f"{func}_fixed"), win=self._get_window())

def _get_window_indexer(
self, index_as_array: Optional[np.ndarray], window: int
) -> BaseIndexer:
def _get_window_indexer(self, window: int) -> BaseIndexer:
"""
Return an indexer class that will compute the window start and end bounds
"""
if isinstance(self.window, BaseIndexer):
return self.window
if self.is_freq_type:
return VariableWindowIndexer(index_array=index_as_array, window_size=window)
return FixedWindowIndexer(index_array=index_as_array, window_size=window)
return VariableWindowIndexer(index_array=self._on.asi8, window_size=window)
return FixedWindowIndexer(window_size=window)

def _apply(
self,
Expand Down Expand Up @@ -476,8 +446,7 @@ def _apply(

blocks, obj = self._create_blocks()
block_list = list(blocks)
index_as_array = self._get_index()
window_indexer = self._get_window_indexer(index_as_array, window)
window_indexer = self._get_window_indexer(window)

results = []
exclude: List[Scalar] = []
Expand All @@ -498,7 +467,7 @@ def _apply(
continue

# calculation function
offset = _offset(window, center) if center else 0
offset = calculate_center_offset(window) if center else 0
additional_nans = np.array([np.nan] * offset)

if not is_weighted:
Expand Down Expand Up @@ -1051,15 +1020,6 @@ def _get_window(
# GH #15662. `False` makes symmetric window, rather than periodic.
return sig.get_window(win_type, window, False).astype(float)

def _get_weighted_roll_func(
self, cfunc: Callable, check_minp: Callable, **kwargs
) -> Callable:
def func(arg, window, min_periods=None, closed=None):
minp = check_minp(min_periods, len(window))
return cfunc(arg, window, minp, **kwargs)

return func

_agg_see_also_doc = dedent(
"""
See Also
Expand Down Expand Up @@ -1127,7 +1087,7 @@ def aggregate(self, func, *args, **kwargs):
def sum(self, *args, **kwargs):
nv.validate_window_func("sum", args, kwargs)
window_func = self._get_roll_func("roll_weighted_sum")
window_func = self._get_weighted_roll_func(window_func, _use_window)
window_func = get_weighted_roll_func(window_func)
return self._apply(
window_func, center=self.center, is_weighted=True, name="sum", **kwargs
)
Expand All @@ -1137,7 +1097,7 @@ def sum(self, *args, **kwargs):
def mean(self, *args, **kwargs):
nv.validate_window_func("mean", args, kwargs)
window_func = self._get_roll_func("roll_weighted_mean")
window_func = self._get_weighted_roll_func(window_func, _use_window)
window_func = get_weighted_roll_func(window_func)
return self._apply(
window_func, center=self.center, is_weighted=True, name="mean", **kwargs
)
Expand All @@ -1147,7 +1107,7 @@ def mean(self, *args, **kwargs):
def var(self, ddof=1, *args, **kwargs):
nv.validate_window_func("var", args, kwargs)
window_func = partial(self._get_roll_func("roll_weighted_var"), ddof=ddof)
window_func = self._get_weighted_roll_func(window_func, _use_window)
window_func = get_weighted_roll_func(window_func)
kwargs.pop("name", None)
return self._apply(
window_func, center=self.center, is_weighted=True, name="var", **kwargs
Expand All @@ -1157,7 +1117,7 @@ def var(self, ddof=1, *args, **kwargs):
@Appender(_shared_docs["std"])
def std(self, ddof=1, *args, **kwargs):
nv.validate_window_func("std", args, kwargs)
return _zsqrt(self.var(ddof=ddof, name="std", **kwargs))
return zsqrt(self.var(ddof=ddof, name="std", **kwargs))


class _Rolling(_Window):
Expand Down Expand Up @@ -1211,8 +1171,6 @@ class _Rolling_and_Expanding(_Rolling):
def count(self):

blocks, obj = self._create_blocks()
# Validate the index
self._get_index()

window = self._get_window()
window = min(window, len(obj)) if not self.center else window
Expand Down Expand Up @@ -1307,7 +1265,7 @@ def apply(
kwargs.pop("_level", None)
kwargs.pop("floor", None)
window = self._get_window()
offset = _offset(window, self.center)
offset = calculate_center_offset(window) if self.center else 0
if not is_bool(raw):
raise ValueError("raw parameter must be `True` or `False`")

Expand Down Expand Up @@ -1478,7 +1436,7 @@ def std(self, ddof=1, *args, **kwargs):
window_func = self._get_cython_func_type("roll_var")

def zsqrt_func(values, begin, end, min_periods):
return _zsqrt(window_func(values, begin, end, min_periods, ddof=ddof))
return zsqrt(window_func(values, begin, end, min_periods, ddof=ddof))

# ddof passed again for compat with groupby.rolling
return self._apply(
Expand Down