-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: Allow users to define their own window bound calculations in rolling #29878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
1c8c24a
1188edd
6b5e894
3a310f6
d1d0775
218395e
46d4a52
c10854d
c237090
1ddc828
a861982
8f482f7
38691c7
9d740d3
d18e954
f06e8e6
7ccbcd0
4e2fd30
c3153d8
89100c4
2704c59
87768ea
6a6d896
2864e95
b16e711
ed08ca3
9eb3022
f358466
0d8cc1f
25a05fe
5d8819f
9194557
e7e1061
09afec4
10c4994
9089f7b
87e391f
7ce1967
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# cython: boundscheck=False, wraparound=False, cdivision=True | ||
|
||
from typing import Tuple | ||
from typing import Optional, Tuple | ||
|
||
import numpy as np | ||
from numpy cimport ndarray, int64_t | ||
|
@@ -10,64 +10,122 @@ from numpy cimport ndarray, int64_t | |
# These define start/end indexers to compute offsets | ||
|
||
|
||
class FixedWindowIndexer: | ||
""" | ||
create a fixed length window indexer object | ||
that has start & end, that point to offsets in | ||
the index object; these are defined based on the win | ||
arguments | ||
|
||
Parameters | ||
---------- | ||
values: ndarray | ||
values data array | ||
win: int64_t | ||
window size | ||
index: object | ||
index of the values | ||
closed: string | ||
closed behavior | ||
""" | ||
def __init__(self, ndarray values, int64_t win, object closed, object index=None): | ||
class BaseIndexer: | ||
"""Base class for window bounds calculations""" | ||
|
||
def __init__( | ||
self, | ||
**kwargs, | ||
): | ||
""" | ||
Parameters | ||
---------- | ||
**kwargs : | ||
keyword argument that will be available when get_window_bounds is called | ||
""" | ||
self.__dict__.update(kwargs) | ||
|
||
def get_window_bounds( | ||
self, | ||
num_values: int = 0, | ||
window_size: int = 0, | ||
min_periods: Optional[int] = None, | ||
jbrockmendel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
center: Optional[bool] = None, | ||
closed: Optional[str] = None, | ||
win_type: Optional[str] = None, | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Computes the bounds of a window. | ||
|
||
Parameters | ||
---------- | ||
num_values : int, default 0 | ||
number of values that will be aggregated over | ||
window_size : int, default 0 | ||
the number of rows in a window | ||
min_periods : int, default None | ||
min_periods passed from the top level rolling API | ||
center : bool, default None | ||
center passed from the top level rolling API | ||
closed : str, default None | ||
closed passed from the top level rolling API | ||
win_type : str, default None | ||
win_type passed from the top level rolling API | ||
|
||
Returns | ||
------- | ||
A tuple of ndarray[int64]s, indicating the boundaries of each | ||
window | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class FixedWindowIndexer(BaseIndexer): | ||
"""Creates window boundaries that are of fixed length.""" | ||
|
||
def get_window_bounds(self, | ||
num_values: int = 0, | ||
window_size: int = 0, | ||
min_periods: Optional[int] = None, | ||
center: Optional[bool] = None, | ||
closed: Optional[str] = None, | ||
win_type: Optional[str] = None, | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Computes the fixed bounds of a window. | ||
|
||
Parameters | ||
---------- | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think these docstrings could be shared. |
||
num_values : int, default 0 | ||
number of values that will be aggregated over | ||
window_size : int, default 0 | ||
the number of rows in a window | ||
min_periods : int, default None | ||
min_periods passed from the top level rolling API | ||
center : bool, default None | ||
center passed from the top level rolling API | ||
closed : str, default None | ||
closed passed from the top level rolling API | ||
win_type : str, default None | ||
win_type passed from the top level rolling API | ||
|
||
Returns | ||
------- | ||
A tuple of ndarray[int64]s, indicating the boundaries of each | ||
window | ||
""" | ||
cdef: | ||
ndarray[int64_t, ndim=1] start_s, start_e, end_s, end_e | ||
int64_t N = len(values) | ||
|
||
start_s = np.zeros(win, dtype='int64') | ||
start_e = np.arange(win, N, dtype='int64') - win + 1 | ||
self.start = np.concatenate([start_s, start_e])[:N] | ||
|
||
end_s = np.arange(win, dtype='int64') + 1 | ||
end_e = start_e + win | ||
self.end = np.concatenate([end_s, end_e])[:N] | ||
|
||
def get_window_bounds(self) -> Tuple[np.ndarray, np.ndarray]: | ||
return self.start, self.end | ||
|
||
|
||
class VariableWindowIndexer: | ||
""" | ||
create a variable length window indexer object | ||
that has start & end, that point to offsets in | ||
the index object; these are defined based on the win | ||
arguments | ||
|
||
Parameters | ||
---------- | ||
values: ndarray | ||
values data array | ||
win: int64_t | ||
window size | ||
index: ndarray | ||
index of the values | ||
closed: string | ||
closed behavior | ||
""" | ||
def __init__(self, ndarray values, int64_t win, object closed, ndarray index): | ||
ndarray[int64_t, ndim=1] start, start_s, start_e, end, end_s, end_e | ||
|
||
start_s = np.zeros(window_size, dtype='int64') | ||
start_e = np.arange(window_size, num_values, dtype='int64') - window_size + 1 | ||
start = np.concatenate([start_s, start_e])[:num_values] | ||
|
||
end_s = np.arange(window_size, dtype='int64') + 1 | ||
end_e = start_e + window_size | ||
end = np.concatenate([end_s, end_e])[:num_values] | ||
return start, end | ||
|
||
|
||
class VariableWindowIndexer(BaseIndexer): | ||
"""Creates window boundaries that are of variable length, namely for time series.""" | ||
|
||
@staticmethod | ||
def _get_window_bound( | ||
int64_t num_values, | ||
int64_t window_size, | ||
object min_periods, | ||
object center, | ||
object closed, | ||
object win_type, | ||
const int64_t[:] index | ||
): | ||
cdef: | ||
bint left_closed = False | ||
bint right_closed = False | ||
int64_t N = len(index) | ||
ndarray[int64_t, ndim=1] start, end | ||
int64_t start_bound, end_bound | ||
Py_ssize_t i, j | ||
|
||
# if windows is variable, default is 'right', otherwise default is 'both' | ||
if closed is None: | ||
|
@@ -79,20 +137,9 @@ class VariableWindowIndexer: | |
if closed in ['left', 'both']: | ||
left_closed = True | ||
|
||
self.start, self.end = self.build(index, win, left_closed, right_closed, N) | ||
|
||
@staticmethod | ||
def build(const int64_t[:] index, int64_t win, bint left_closed, | ||
bint right_closed, int64_t N) -> Tuple[np.ndarray, np.ndarray]: | ||
|
||
cdef: | ||
ndarray[int64_t] start, end | ||
int64_t start_bound, end_bound | ||
Py_ssize_t i, j | ||
|
||
start = np.empty(N, dtype='int64') | ||
start = np.empty(num_values, dtype='int64') | ||
start.fill(-1) | ||
end = np.empty(N, dtype='int64') | ||
end = np.empty(num_values, dtype='int64') | ||
end.fill(-1) | ||
|
||
start[0] = 0 | ||
|
@@ -108,9 +155,9 @@ class VariableWindowIndexer: | |
|
||
# start is start of slice interval (including) | ||
# end is end of slice interval (not including) | ||
for i in range(1, N): | ||
for i in range(1, num_values): | ||
end_bound = index[i] | ||
start_bound = index[i] - win | ||
start_bound = index[i] - window_size | ||
|
||
# left endpoint is closed | ||
if left_closed: | ||
|
@@ -136,5 +183,38 @@ class VariableWindowIndexer: | |
end[i] -= 1 | ||
return start, end | ||
|
||
def get_window_bounds(self) -> Tuple[np.ndarray, np.ndarray]: | ||
return self.start, self.end | ||
def get_window_bounds(self, | ||
num_values: int = 0, | ||
window_size: int = 0, | ||
min_periods: Optional[int] = None, | ||
center: Optional[bool] = None, | ||
closed: Optional[str] = None, | ||
win_type: Optional[str] = None, | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Computes the variable bounds of a window. | ||
|
||
Parameters | ||
---------- | ||
num_values : int, default 0 | ||
number of values that will be aggregated over | ||
window_size : int, default 0 | ||
the number of rows in a window | ||
min_periods : int, default None | ||
min_periods passed from the top level rolling API | ||
center : bool, default None | ||
center passed from the top level rolling API | ||
closed : str, default None | ||
closed passed from the top level rolling API | ||
win_type : str, default None | ||
win_type passed from the top level rolling API | ||
|
||
Returns | ||
------- | ||
A tuple of ndarray[int64]s, indicating the boundaries of each | ||
window | ||
""" | ||
# We do this since cython doesn't like accessing class attributes in nogil | ||
return self._get_window_bound( | ||
num_values, window_size, min_periods, center, closed, win_type, self.index | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
""" public toolkit API """ | ||
from . import extensions, types # noqa | ||
from . import extensions, indexers, types # noqa |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""Public API for Rolling Window Indexers""" | ||
from pandas._libs.window.indexers import BaseIndexer # noqa: F401 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -400,13 +400,15 @@ def _get_cython_func_type(self, func): | |
self._get_roll_func("{}_fixed".format(func)), win=self._get_window() | ||
) | ||
|
||
def _get_window_indexer(self): | ||
def _get_window_indexer(self, index_as_array): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can index_as_array or the return type be annotated? |
||
""" | ||
Return an indexer class that will compute the window start and end bounds | ||
""" | ||
if isinstance(self.window, window_indexers.BaseIndexer): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Technically we don’t actually care whether this is a BaseIndexer; rather, we care that it has a get_window_bounds method with the correct signature. |
||
return self.window | ||
if self.is_freq_type: | ||
return window_indexers.VariableWindowIndexer | ||
return window_indexers.FixedWindowIndexer | ||
return window_indexers.VariableWindowIndexer(index=index_as_array) | ||
return window_indexers.FixedWindowIndexer() | ||
|
||
def _apply( | ||
self, | ||
|
@@ -445,7 +447,7 @@ def _apply( | |
blocks, obj = self._create_blocks() | ||
block_list = list(blocks) | ||
index_as_array = self._get_index() | ||
window_indexer = self._get_window_indexer() | ||
window_indexer = self._get_window_indexer(index_as_array) | ||
|
||
results = [] | ||
exclude: List[Scalar] = [] | ||
|
@@ -476,9 +478,9 @@ def calc(x): | |
min_periods = calculate_min_periods( | ||
window, self.min_periods, len(x), require_min_periods, floor | ||
) | ||
start, end = window_indexer( | ||
x, window, self.closed, index_as_array | ||
).get_window_bounds() | ||
start, end = window_indexer.get_window_bounds( | ||
num_values=len(x), window_size=window, closed=self.closed | ||
) | ||
return func(x, start, end, min_periods) | ||
|
||
else: | ||
|
@@ -759,13 +761,18 @@ class Window(_Window): | |
|
||
Parameters | ||
---------- | ||
window : int, or offset | ||
window : int, offset, or BaseIndexer subclass | ||
Size of the moving window. This is the number of observations used for | ||
calculating the statistic. Each window will be a fixed size. | ||
|
||
If its an offset then this will be the time period of each window. Each | ||
window will be a variable sized based on the observations included in | ||
the time-period. This is only valid for datetimelike indexes. | ||
|
||
If a BaseIndexer subclass is passed, calculates the window boundaries | ||
based on the defined ``get_window_bounds`` method. Additional rolling | ||
keyword arguments, namely `min_periods`, `center`, `win_type`, and | ||
`closed` will be passed to `get_window_bounds`. | ||
min_periods : int, default None | ||
Minimum number of observations in window required to have a value | ||
(otherwise result is NA). For a window that is specified by an offset, | ||
|
@@ -906,7 +913,7 @@ def validate(self): | |
super().validate() | ||
|
||
window = self.window | ||
if isinstance(window, (list, tuple, np.ndarray)): | ||
if isinstance(window, (list, tuple, np.ndarray, window_indexers.BaseIndexer)): | ||
pass | ||
elif is_integer(window): | ||
if window <= 0: | ||
|
@@ -995,6 +1002,13 @@ def _get_window( | |
|
||
# GH #15662. `False` makes symmetric window, rather than periodic. | ||
return sig.get_window(win_type, window, False).astype(float) | ||
elif isinstance(window, window_indexers.BaseIndexer): | ||
return window.get_window_bounds( | ||
win_type=self.win_type, | ||
min_periods=self.min_periods, | ||
center=self.center, | ||
closed=self.closed, | ||
) | ||
|
||
def _get_weighted_roll_func( | ||
self, cfunc: Callable, check_minp: Callable, **kwargs | ||
|
@@ -1762,6 +1776,8 @@ def validate(self): | |
if self.min_periods is None: | ||
self.min_periods = 1 | ||
|
||
elif isinstance(self.window, window_indexers.BaseIndexer): | ||
pass | ||
elif not is_integer(self.window): | ||
raise ValueError("window must be an integer") | ||
elif self.window < 0: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this doing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yah can we avoid this pattern?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah sorry, it's a bit cryptic.
The motivation is to allow a way for
get_window_bounds
to access an external variable that may not be passed in rolling
or the aggregation function, by making it accessible as a class attribute. For example:
Could make this pattern in the
__init__
more explicit and add validation.