# Source: common.py (337 lines, 278 loc, 10.6 KB), from a fork of pandas-dev/pandas.
"""Common utility functions for rolling operations"""
from collections import defaultdict
from typing import Callable, Optional
import warnings
import numpy as np
from pandas.core.dtypes.common import is_integer
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
import pandas.core.common as com
from pandas.core.generic import _shared_docs
from pandas.core.groupby.base import GroupByMixin
from pandas.core.indexes.api import MultiIndex
# ``dict(**d)`` builds a new dict, so this module-level name refers to a
# shallow copy rather than to the mapping imported from pandas.core.generic
# (presumably so that entries added for window docs stay local -- confirm).
_shared_docs = dict(**_shared_docs)
# Docstring fragment containing ``%(name)s`` placeholders; it is meant to be
# completed with %-formatting against a mapping that provides ``name``.
_doc_template = """
Returns
-------
Series or DataFrame
Return type is determined by the caller.
See Also
--------
Series.%(name)s : Series %(name)s.
DataFrame.%(name)s : DataFrame %(name)s.
"""
def _dispatch(name: str, *args, **kwargs):
"""
Dispatch to apply.
"""
def outer(self, *args, **kwargs):
def f(x):
x = self._shallow_copy(x, groupby=self._groupby)
return getattr(x, name)(*args, **kwargs)
return self._groupby.apply(f)
outer.__name__ = name
return outer
class WindowGroupByMixin(GroupByMixin):
    """
    Mixin that routes window operations through a groupby object.
    """

    def __init__(self, obj, *args, **kwargs):
        # "parent" is dropped and never forwarded to the parent constructor.
        kwargs.pop("parent", None)
        grouped = kwargs.pop("groupby", None)
        if grouped is None:
            # No explicit groupby given: the first positional argument is the
            # groupby itself, and the object to operate on is its ``.obj``.
            grouped = obj
            obj = grouped.obj
        self._groupby = grouped
        self._groupby.mutated = True
        self._groupby.grouper.mutated = True
        super().__init__(obj, *args, **kwargs)

    count = _dispatch("count")
    corr = _dispatch("corr", other=None, pairwise=None)
    cov = _dispatch("cov", other=None, pairwise=None)

    def _apply(
        self,
        func: Callable,
        center: bool,
        require_min_periods: int = 0,
        floor: int = 1,
        is_weighted: bool = False,
        name: Optional[str] = None,
        **kwargs,
    ):
        """
        Re-run the originally requested operation on every group.

        The ``_apply``-specific arguments (``func``, ``center``,
        ``require_min_periods``, ``floor``, ``is_weighted``) are not used
        here; only ``name`` and the remaining ``kwargs`` reach the per-group
        call.
        """
        # ``floor`` is a named parameter, so it cannot also appear in
        # ``kwargs``; the pop is kept as a harmless safeguard.
        kwargs.pop("floor", None)

        # TODO: can we de-duplicate with _dispatch?
        def per_group(group, name=name, *args):
            windowed = self._shallow_copy(group)
            if not isinstance(name, str):
                # Non-string ``name`` is handed to ``apply`` as-is.
                return windowed.apply(name, *args, **kwargs)
            return getattr(windowed, name)(*args, **kwargs)

        return self._groupby.apply(per_group)
def _flex_binary_moment(arg1, arg2, f, pairwise=False):
    """
    Apply the binary function ``f`` to ``arg1`` and ``arg2``, dispatching on
    the container types of the two arguments.

    Parameters
    ----------
    arg1, arg2 : np.ndarray, Series or DataFrame
    f : callable
        Called as ``f(x, y)`` for each pair of inputs.
    pairwise : bool, default False
        Only consulted when both arguments are DataFrames. ``False`` pairs
        matching columns; ``True`` evaluates every combination of an
        ``arg1`` column with an ``arg2`` column.

    Returns
    -------
    object
        The return value of ``f`` when neither argument is a DataFrame,
        otherwise a DataFrame assembled from the per-column results.

    Raises
    ------
    TypeError
        If either argument is not an ndarray, Series or DataFrame.
    ValueError
        If two distinct DataFrames are given with ``pairwise=False`` and
        either has non-unique columns, or if ``pairwise`` is neither
        ``True`` nor ``False`` when both arguments are DataFrames.
    """
    if not (
        isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
        and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
    ):
        raise TypeError(
            "arguments to moment function must be of type "
            "np.ndarray/Series/DataFrame"
        )

    # Case 1: neither argument is a DataFrame -> a single call to ``f``.
    if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
        arg2, (np.ndarray, ABCSeries)
    ):
        X, Y = _prep_binary(arg1, arg2)
        return f(X, Y)

    # Case 2: arg1 is a DataFrame (arg2 may be anything accepted above).
    elif isinstance(arg1, ABCDataFrame):
        from pandas import DataFrame

        def dataframe_from_int_dict(data, frame_template):
            # ``data`` is keyed by integer column position; build the frame on
            # the template's index, then translate the positions back into
            # the template's column labels.
            result = DataFrame(data, index=frame_template.index)
            if len(result.columns) > 0:
                result.columns = frame_template.columns[result.columns]
            return result

        results = {}
        if isinstance(arg2, ABCDataFrame):
            if pairwise is False:
                if arg1 is arg2:
                    # special case in order to handle duplicate column names:
                    # columns are addressed by position rather than by label
                    for i, col in enumerate(arg1.columns):
                        results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
                    return dataframe_from_int_dict(results, arg1)
                else:
                    # Label-based matching below requires unique labels.
                    if not arg1.columns.is_unique:
                        raise ValueError("'arg1' columns are not unique")
                    if not arg2.columns.is_unique:
                        raise ValueError("'arg2' columns are not unique")
                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        X, Y = arg1.align(arg2, join="outer")
                    # Adding ``0 * other`` leaves values unchanged where both
                    # sides are present and propagates missing values from
                    # the other frame.
                    X = X + 0 * Y
                    Y = Y + 0 * X

                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        res_columns = arg1.columns.union(arg2.columns)
                    for col in res_columns:
                        if col in X and col in Y:
                            results[col] = f(X[col], Y[col])
                    return DataFrame(results, index=X.index, columns=res_columns)
            elif pairwise is True:
                # results[i][j] holds f(arg1 column i, arg2 column j).
                results = defaultdict(dict)
                for i, k1 in enumerate(arg1.columns):
                    for j, k2 in enumerate(arg2.columns):
                        if j < i and arg2 is arg1:
                            # Symmetric case: same frame on both sides, so
                            # reuse the value already computed for (j, i)
                            results[i][j] = results[j][i]
                        else:
                            results[i][j] = f(
                                *_prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
                            )

                from pandas import concat

                result_index = arg1.index.union(arg2.index)
                if len(result_index):

                    # construct result frame: for each arg1 column, stack the
                    # results against every arg2 column vertically, then put
                    # those stacks side by side
                    result = concat(
                        [
                            concat(
                                [results[i][j] for j, c in enumerate(arg2.columns)],
                                ignore_index=True,
                            )
                            for i, c in enumerate(arg1.columns)
                        ],
                        ignore_index=True,
                        axis=1,
                    )
                    result.columns = arg1.columns

                    # set the index and reorder
                    if arg2.columns.nlevels > 1:
                        result.index = MultiIndex.from_product(
                            arg2.columns.levels + [result_index]
                        )
                        result = result.reorder_levels([2, 0, 1]).sort_index()
                    else:
                        # Label rows with positional codes first, reorder so
                        # the original row position is the outer level, then
                        # replace the codes with the real labels.
                        result.index = MultiIndex.from_product(
                            [range(len(arg2.columns)), range(len(result_index))]
                        )
                        result = result.swaplevel(1, 0).sort_index()
                        result.index = MultiIndex.from_product(
                            [result_index] + [arg2.columns]
                        )
                else:

                    # empty result
                    result = DataFrame(
                        index=MultiIndex(
                            levels=[arg1.index, arg2.columns], codes=[[], []]
                        ),
                        columns=arg2.columns,
                        dtype="float64",
                    )

                # reset our index names to arg1 names
                # reset our column names to arg2 names
                # careful not to mutate the original names
                result.columns = result.columns.set_names(arg1.columns.names)
                result.index = result.index.set_names(
                    result_index.names + arg2.columns.names
                )

                return result

            else:
                raise ValueError("'pairwise' is not True/False")
        else:
            # arg2 is an ndarray/Series: pair it with each column of arg1,
            # keyed by column position.
            results = {
                i: f(*_prep_binary(arg1.iloc[:, i], arg2))
                for i, col in enumerate(arg1.columns)
            }
            return dataframe_from_int_dict(results, arg1)

    # Case 3: arg1 is an ndarray/Series and arg2 is the DataFrame. Retry with
    # the arguments swapped. Note that ``f`` then receives its inputs in
    # swapped order, and ``pairwise`` is not forwarded (it is only read when
    # both arguments are DataFrames).
    else:
        return _flex_binary_moment(arg2, arg1, f)
def _get_center_of_mass(comass, span, halflife, alpha):
valid_count = com.count_not_none(comass, span, halflife, alpha)
if valid_count > 1:
raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
# Convert to center of mass; domain checks ensure 0 < alpha <= 1
if comass is not None:
if comass < 0:
raise ValueError("comass must satisfy: comass >= 0")
elif span is not None:
if span < 1:
raise ValueError("span must satisfy: span >= 1")
comass = (span - 1) / 2.0
elif halflife is not None:
if halflife <= 0:
raise ValueError("halflife must satisfy: halflife > 0")
decay = 1 - np.exp(np.log(0.5) / halflife)
comass = 1 / decay - 1
elif alpha is not None:
if alpha <= 0 or alpha > 1:
raise ValueError("alpha must satisfy: 0 < alpha <= 1")
comass = (1.0 - alpha) / alpha
else:
raise ValueError("Must pass one of comass, span, halflife, or alpha")
return float(comass)
def _offset(window, center):
if not is_integer(window):
window = len(window)
offset = (window - 1) / 2.0 if center else 0
try:
return int(offset)
except TypeError:
return offset.astype(int)
def _require_min_periods(p):
def _check_func(minp, window):
if minp is None:
return window
else:
return max(p, minp)
return _check_func
def _use_window(minp, window):
if minp is None:
return window
else:
return minp
def calculate_min_periods(
    window: int,
    min_periods: Optional[int],
    num_values: int,
    required_min_periods: int,
    floor: int,
) -> int:
    """
    Calculates final minimum periods value for rolling aggregations.

    Parameters
    ----------
    window : passed window value
    min_periods : passed min periods value; ``None`` means "use ``window``"
    num_values : total number of values
    required_min_periods : required min periods per aggregation function
    floor : lower bound applied to the returned value

    Returns
    -------
    min_periods : int

    Raises
    ------
    ValueError
        If ``min_periods`` is negative, or if the effective value exceeds
        ``window``.
    """
    if min_periods is None:
        min_periods = window
    else:
        # Validate the caller's value *before* raising it to the required
        # minimum; otherwise a negative input is silently clamped and the
        # ">= 0" error below can never fire for it.
        if min_periods < 0:
            raise ValueError("min_periods must be >= 0")
        min_periods = max(required_min_periods, min_periods)
    if min_periods > window:
        raise ValueError(f"min_periods {min_periods} must be <= window {window}")
    elif min_periods > num_values:
        # More periods required than values available: return a value no
        # window over ``num_values`` observations can satisfy.
        min_periods = num_values + 1
    elif min_periods < 0:
        # Still reachable when ``min_periods`` defaulted to a negative window.
        raise ValueError("min_periods must be >= 0")
    return max(min_periods, floor)
def _zsqrt(x):
with np.errstate(all="ignore"):
result = np.sqrt(x)
mask = x < 0
if isinstance(x, ABCDataFrame):
if mask.values.any():
result[mask] = 0
else:
if mask.any():
result[mask] = 0
return result
def _prep_binary(arg1, arg2):
if not isinstance(arg2, type(arg1)):
raise Exception("Input arrays must be of the same type!")
# mask out values, this also makes a common index...
X = arg1 + 0 * arg2
Y = arg2 + 0 * arg1
return X, Y