|
3 | 3 | import datetime
|
4 | 4 | from functools import partial
|
5 | 5 | from textwrap import dedent
|
6 |
| -from typing import TYPE_CHECKING, Optional, Union |
| 6 | +from typing import Optional, Union |
7 | 7 | import warnings
|
8 | 8 |
|
9 | 9 | import numpy as np
|
10 | 10 |
|
11 | 11 | from pandas._libs.tslibs import Timedelta
|
12 | 12 | import pandas._libs.window.aggregations as window_aggregations
|
13 |
| -from pandas._typing import FrameOrSeries, TimedeltaConvertibleTypes |
| 13 | +from pandas._typing import FrameOrSeries, FrameOrSeriesUnion, TimedeltaConvertibleTypes |
14 | 14 | from pandas.compat.numpy import function as nv
|
15 | 15 | from pandas.util._decorators import doc
|
16 | 16 |
|
|
19 | 19 |
|
20 | 20 | import pandas.core.common as common
|
21 | 21 | from pandas.core.util.numba_ import maybe_use_numba
|
22 |
| -from pandas.core.window.common import flex_binary_moment, zsqrt |
| 22 | +from pandas.core.window.common import zsqrt |
23 | 23 | from pandas.core.window.doc import (
|
24 | 24 | _shared_docs,
|
25 | 25 | args_compat,
|
|
35 | 35 | GroupbyIndexer,
|
36 | 36 | )
|
37 | 37 | from pandas.core.window.numba_ import generate_numba_groupby_ewma_func
|
38 |
| -from pandas.core.window.rolling import BaseWindow, BaseWindowGroupby, dispatch |
39 |
| - |
40 |
| -if TYPE_CHECKING: |
41 |
| - from pandas import Series |
| 38 | +from pandas.core.window.rolling import BaseWindow, BaseWindowGroupby |
42 | 39 |
|
43 | 40 |
|
44 | 41 | def get_center_of_mass(
|
@@ -74,13 +71,20 @@ def get_center_of_mass(
|
74 | 71 | return float(comass)
|
75 | 72 |
|
76 | 73 |
|
77 |
| -def wrap_result(obj: Series, result: np.ndarray) -> Series: |
| 74 | +def dispatch(name: str, *args, **kwargs): |
78 | 75 | """
|
79 |
| - Wrap a single 1D result. |
| 76 | + Dispatch to groupby apply. |
80 | 77 | """
|
81 |
| - obj = obj._selected_obj |
82 | 78 |
|
83 |
| - return obj._constructor(result, obj.index, name=obj.name) |
| 79 | + def outer(self, *args, **kwargs): |
| 80 | + def f(x): |
| 81 | + x = self._shallow_copy(x, groupby=self._groupby) |
| 82 | + return getattr(x, name)(*args, **kwargs) |
| 83 | + |
| 84 | + return self._groupby.apply(f) |
| 85 | + |
| 86 | + outer.__name__ = name |
| 87 | + return outer |
84 | 88 |
|
85 | 89 |
|
86 | 90 | class ExponentialMovingWindow(BaseWindow):
|
@@ -443,36 +447,30 @@ def var_func(values, begin, end, min_periods):
|
443 | 447 | )
|
444 | 448 | def cov(
|
445 | 449 | self,
|
446 |
| - other: Optional[Union[np.ndarray, FrameOrSeries]] = None, |
| 450 | + other: Optional[FrameOrSeriesUnion] = None, |
447 | 451 | pairwise: Optional[bool] = None,
|
448 | 452 | bias: bool = False,
|
449 | 453 | **kwargs,
|
450 | 454 | ):
|
451 |
| - if other is None: |
452 |
| - other = self._selected_obj |
453 |
| - # only default unset |
454 |
| - pairwise = True if pairwise is None else pairwise |
455 |
| - other = self._shallow_copy(other) |
456 |
| - |
457 |
| - def _get_cov(X, Y): |
458 |
| - X = self._shallow_copy(X) |
459 |
| - Y = self._shallow_copy(Y) |
460 |
| - cov = window_aggregations.ewmcov( |
461 |
| - X._prep_values(), |
| 455 | + from pandas import Series |
| 456 | + |
| 457 | + def cov_func(x, y): |
| 458 | + x_array = self._prep_values(x) |
| 459 | + y_array = self._prep_values(y) |
| 460 | + result = window_aggregations.ewmcov( |
| 461 | + x_array, |
462 | 462 | np.array([0], dtype=np.int64),
|
463 | 463 | np.array([0], dtype=np.int64),
|
464 | 464 | self.min_periods,
|
465 |
| - Y._prep_values(), |
| 465 | + y_array, |
466 | 466 | self.com,
|
467 | 467 | self.adjust,
|
468 | 468 | self.ignore_na,
|
469 | 469 | bias,
|
470 | 470 | )
|
471 |
| - return wrap_result(X, cov) |
| 471 | + return Series(result, index=x.index, name=x.name) |
472 | 472 |
|
473 |
| - return flex_binary_moment( |
474 |
| - self._selected_obj, other._selected_obj, _get_cov, pairwise=bool(pairwise) |
475 |
| - ) |
| 473 | + return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func) |
476 | 474 |
|
477 | 475 | @doc(
|
478 | 476 | template_header,
|
@@ -502,45 +500,37 @@ def _get_cov(X, Y):
|
502 | 500 | )
|
503 | 501 | def corr(
|
504 | 502 | self,
|
505 |
| - other: Optional[Union[np.ndarray, FrameOrSeries]] = None, |
| 503 | + other: Optional[FrameOrSeriesUnion] = None, |
506 | 504 | pairwise: Optional[bool] = None,
|
507 | 505 | **kwargs,
|
508 | 506 | ):
|
509 |
| - if other is None: |
510 |
| - other = self._selected_obj |
511 |
| - # only default unset |
512 |
| - pairwise = True if pairwise is None else pairwise |
513 |
| - other = self._shallow_copy(other) |
| 507 | + from pandas import Series |
514 | 508 |
|
515 |
| - def _get_corr(X, Y): |
516 |
| - X = self._shallow_copy(X) |
517 |
| - Y = self._shallow_copy(Y) |
| 509 | + def cov_func(x, y): |
| 510 | + x_array = self._prep_values(x) |
| 511 | + y_array = self._prep_values(y) |
518 | 512 |
|
519 |
| - def _cov(x, y): |
| 513 | + def _cov(X, Y): |
520 | 514 | return window_aggregations.ewmcov(
|
521 |
| - x, |
| 515 | + X, |
522 | 516 | np.array([0], dtype=np.int64),
|
523 | 517 | np.array([0], dtype=np.int64),
|
524 | 518 | self.min_periods,
|
525 |
| - y, |
| 519 | + Y, |
526 | 520 | self.com,
|
527 | 521 | self.adjust,
|
528 | 522 | self.ignore_na,
|
529 | 523 | 1,
|
530 | 524 | )
|
531 | 525 |
|
532 |
| - x_values = X._prep_values() |
533 |
| - y_values = Y._prep_values() |
534 | 526 | with np.errstate(all="ignore"):
|
535 |
| - cov = _cov(x_values, y_values) |
536 |
| - x_var = _cov(x_values, x_values) |
537 |
| - y_var = _cov(y_values, y_values) |
538 |
| - corr = cov / zsqrt(x_var * y_var) |
539 |
| - return wrap_result(X, corr) |
540 |
| - |
541 |
| - return flex_binary_moment( |
542 |
| - self._selected_obj, other._selected_obj, _get_corr, pairwise=bool(pairwise) |
543 |
| - ) |
| 527 | + cov = _cov(x_array, y_array) |
| 528 | + x_var = _cov(x_array, x_array) |
| 529 | + y_var = _cov(y_array, y_array) |
| 530 | + result = cov / zsqrt(x_var * y_var) |
| 531 | + return Series(result, index=x.index, name=x.name) |
| 532 | + |
| 533 | + return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func) |
544 | 534 |
|
545 | 535 |
|
546 | 536 | class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
|
|
0 commit comments