"""Common utility functions for rolling operations"""
from collections import defaultdict
import warnings

import numpy as np

from pandas.core.dtypes.common import is_integer
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries

import pandas.core.common as com
from pandas.core.generic import _shared_docs
from pandas.core.groupby.base import GroupByMixin
from pandas.core.index import MultiIndex

# Copy the shared docs so additions made here do not mutate the imported dict.
_shared_docs = dict(**_shared_docs)
_doc_template = """
    Returns
    -------
    Series or DataFrame
        Return type is determined by the caller.

    See Also
    --------
    Series.%(name)s : Series %(name)s.
    DataFrame.%(name)s : DataFrame %(name)s.
"""


class _GroupByMixin(GroupByMixin):
    """
    Provide the groupby facilities.
    """

    def __init__(self, obj, *args, **kwargs):
        parent = kwargs.pop("parent", None)  # noqa
        groupby = kwargs.pop("groupby", None)
        if groupby is None:
            groupby, obj = obj, obj.obj
        self._groupby = groupby
        self._groupby.mutated = True
        self._groupby.grouper.mutated = True
        super().__init__(obj, *args, **kwargs)

    count = GroupByMixin._dispatch("count")
    corr = GroupByMixin._dispatch("corr", other=None, pairwise=None)
    cov = GroupByMixin._dispatch("cov", other=None, pairwise=None)

    def _apply(
        self, func, name=None, window=None, center=None, check_minp=None, **kwargs
    ):
        """
        Dispatch to apply; we are stripping all of the _apply kwargs and
        performing the original function call on the grouped object.
        """

        def f(x, name=name, *args):
            x = self._shallow_copy(x)

            if isinstance(name, str):
                return getattr(x, name)(*args, **kwargs)

            return x.apply(name, *args, **kwargs)

        return self._groupby.apply(f)
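

# A sketch of how _GroupByMixin above is used (illustrative; assumes the pandas
# groupby-rolling path, e.g. ``df.groupby("key").rolling(2).mean()``): ``_apply``
# re-wraps each group with ``_shallow_copy`` and calls the named method on the
# per-group windowed object, so the original function runs once per group.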


def _flex_binary_moment(arg1, arg2, f, pairwise=False):
    """
    Apply the binary moment function ``f`` to ``arg1`` and ``arg2``,
    broadcasting over the columns of DataFrame arguments and, when
    ``pairwise=True``, over all pairs of columns.
    """
    if not (
        isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
        and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
    ):
        raise TypeError(
            "arguments to moment function must be of type "
            "np.ndarray/Series/DataFrame"
        )

    if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
        arg2, (np.ndarray, ABCSeries)
    ):
        X, Y = _prep_binary(arg1, arg2)
        return f(X, Y)

    elif isinstance(arg1, ABCDataFrame):
        from pandas import DataFrame

        def dataframe_from_int_dict(data, frame_template):
            result = DataFrame(data, index=frame_template.index)
            if len(result.columns) > 0:
                result.columns = frame_template.columns[result.columns]
            return result

        results = {}
        if isinstance(arg2, ABCDataFrame):
            if pairwise is False:
                if arg1 is arg2:
                    # special case in order to handle duplicate column names
                    for i, col in enumerate(arg1.columns):
                        results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
                    return dataframe_from_int_dict(results, arg1)
                else:
                    if not arg1.columns.is_unique:
                        raise ValueError("'arg1' columns are not unique")
                    if not arg2.columns.is_unique:
                        raise ValueError("'arg2' columns are not unique")
                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        X, Y = arg1.align(arg2, join="outer")
                    X = X + 0 * Y
                    Y = Y + 0 * X

                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        res_columns = arg1.columns.union(arg2.columns)
                    for col in res_columns:
                        if col in X and col in Y:
                            results[col] = f(X[col], Y[col])
                    return DataFrame(results, index=X.index, columns=res_columns)
            elif pairwise is True:
                results = defaultdict(dict)
                for i, k1 in enumerate(arg1.columns):
                    for j, k2 in enumerate(arg2.columns):
                        if j < i and arg2 is arg1:
                            # Symmetric case
                            results[i][j] = results[j][i]
                        else:
                            results[i][j] = f(
                                *_prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
                            )

                from pandas import concat

                result_index = arg1.index.union(arg2.index)
                if len(result_index):

                    # construct result frame
                    result = concat(
                        [
                            concat(
                                [results[i][j] for j, c in enumerate(arg2.columns)],
                                ignore_index=True,
                            )
                            for i, c in enumerate(arg1.columns)
                        ],
                        ignore_index=True,
                        axis=1,
                    )
                    result.columns = arg1.columns

                    # set the index and reorder
                    if arg2.columns.nlevels > 1:
                        result.index = MultiIndex.from_product(
                            arg2.columns.levels + [result_index]
                        )
                        result = result.reorder_levels([2, 0, 1]).sort_index()
                    else:
                        result.index = MultiIndex.from_product(
                            [range(len(arg2.columns)), range(len(result_index))]
                        )
                        result = result.swaplevel(1, 0).sort_index()
                        result.index = MultiIndex.from_product(
                            [result_index] + [arg2.columns]
                        )
                else:

                    # empty result
                    result = DataFrame(
                        index=MultiIndex(
                            levels=[arg1.index, arg2.columns], codes=[[], []]
                        ),
                        columns=arg2.columns,
                        dtype="float64",
                    )

                # reset our index names to arg1 names
                # reset our column names to arg2 names
                # careful not to mutate the original names
                result.columns = result.columns.set_names(arg1.columns.names)
                result.index = result.index.set_names(
                    result_index.names + arg2.columns.names
                )

                return result

            else:
                raise ValueError("'pairwise' is not True/False")
        else:
            results = {
                i: f(*_prep_binary(arg1.iloc[:, i], arg2))
                for i, col in enumerate(arg1.columns)
            }
            return dataframe_from_int_dict(results, arg1)

    else:
        return _flex_binary_moment(arg2, arg1, f)
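

# Illustrative sketch of the shapes _flex_binary_moment produces (not executed
# here; ``f`` is normally a rolling covariance/correlation kernel, but any
# binary function of two Series works; assumes pandas is imported as ``pd``):
#
#     df = pd.DataFrame({"A": [1.0, 2.0, 3.0], "B": [2.0, 4.0, 6.0]})
#     _flex_binary_moment(df, df, lambda x, y: x + y, pairwise=False)
#     # -> DataFrame with df's index/columns; each column combined with itself
#     _flex_binary_moment(df, df, lambda x, y: x + y, pairwise=True)
#     # -> DataFrame indexed by (row label, column), one column per column of df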


def _get_center_of_mass(comass, span, halflife, alpha):
    valid_count = com.count_not_none(comass, span, halflife, alpha)
    if valid_count > 1:
        raise ValueError("comass, span, halflife, and alpha are mutually exclusive")

    # Convert to center of mass; domain checks ensure 0 < alpha <= 1
    if comass is not None:
        if comass < 0:
            raise ValueError("comass must satisfy: comass >= 0")
    elif span is not None:
        if span < 1:
            raise ValueError("span must satisfy: span >= 1")
        comass = (span - 1) / 2.0
    elif halflife is not None:
        if halflife <= 0:
            raise ValueError("halflife must satisfy: halflife > 0")
        decay = 1 - np.exp(np.log(0.5) / halflife)
        comass = 1 / decay - 1
    elif alpha is not None:
        if alpha <= 0 or alpha > 1:
            raise ValueError("alpha must satisfy: 0 < alpha <= 1")
        comass = (1.0 - alpha) / alpha
    else:
        raise ValueError("Must pass one of comass, span, halflife, or alpha")

    return float(comass)
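

# Worked examples for _get_center_of_mass (each parametrization below resolves
# to the same center of mass, 1.0):
#
#     _get_center_of_mass(1.0, None, None, None)   # comass passed through
#     _get_center_of_mass(None, 3, None, None)     # span: (3 - 1) / 2 = 1.0
#     _get_center_of_mass(None, None, 1, None)     # halflife: 1 / 0.5 - 1 = 1.0
#     _get_center_of_mass(None, None, None, 0.5)   # alpha: (1 - 0.5) / 0.5 = 1.0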


def _offset(window, center):
    if not is_integer(window):
        window = len(window)
    offset = (window - 1) / 2.0 if center else 0
    try:
        return int(offset)
    except TypeError:
        return offset.astype(int)
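

# Examples for _offset; ``window`` may be an integer or a window-like sequence,
# and the offset is only non-zero for centered windows:
#
#     _offset(5, True)    # -> 2  ((5 - 1) / 2, truncated to int)
#     _offset(4, True)    # -> 1  (1.5 truncated toward zero)
#     _offset(5, False)   # -> 0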


def _require_min_periods(p):
    def _check_func(minp, window):
        if minp is None:
            return window
        else:
            return max(p, minp)

    return _check_func


def _use_window(minp, window):
    if minp is None:
        return window
    else:
        return minp
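

# _require_min_periods and _use_window resolve the effective ``min_periods``:
#
#     _require_min_periods(4)(None, 5)   # -> 5 (default to the window size)
#     _require_min_periods(4)(2, 5)      # -> 4 (enforce the required floor)
#     _use_window(None, 5)               # -> 5
#     _use_window(2, 5)                  # -> 2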


def _zsqrt(x):
    with np.errstate(all="ignore"):
        result = np.sqrt(x)
        mask = x < 0

    if isinstance(x, ABCDataFrame):
        if mask.values.any():
            result[mask] = 0
    else:
        if mask.any():
            result[mask] = 0

    return result
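

# _zsqrt behaves like np.sqrt but maps negative inputs to 0 instead of NaN,
# which guards against tiny negative variances caused by floating-point error
# (illustrative, assuming pandas is imported as ``pd``):
#
#     _zsqrt(pd.Series([4.0, -1e-12]))   # -> Series([2.0, 0.0])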


def _prep_binary(arg1, arg2):
    if not isinstance(arg2, type(arg1)):
        raise TypeError("Input arrays must be of the same type!")

    # mask out values, this also makes a common index...
    X = arg1 + 0 * arg2
    Y = arg2 + 0 * arg1

    return X, Y
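

# _prep_binary requires both inputs to share a type; the ``+ 0 *`` trick aligns
# them and masks each one where the other is missing (illustrative, assuming
# pandas is imported as ``pd``):
#
#     s1 = pd.Series([1.0, 2.0], index=["a", "b"])
#     s2 = pd.Series([3.0, 4.0], index=["b", "c"])
#     X, Y = _prep_binary(s1, s2)
#     # X -> [NaN, 2.0, NaN], Y -> [NaN, 3.0, NaN], both on index ["a", "b", "c"]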