Skip to content

Commit 883adba

Browse files
rhshadrachyehoshuadimarsky
authored andcommitted
ENH: Add numeric_only to window ops (pandas-dev#47265)
* ENH: Add numeric_only to window ops * Fix corr/cov for Series; add tests
1 parent e05b362 commit 883adba

File tree

9 files changed

+659
-87
lines changed

9 files changed

+659
-87
lines changed

doc/source/whatsnew/v1.5.0.rst

+3
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,9 @@ gained the ``numeric_only`` argument.
654654
- :meth:`.Resampler.sem`
655655
- :meth:`.Resampler.std`
656656
- :meth:`.Resampler.var`
657+
- :meth:`DataFrame.rolling` operations
658+
- :meth:`DataFrame.expanding` operations
659+
- :meth:`DataFrame.ewm` operations
657660

658661
.. _whatsnew_150.deprecations.other:
659662

pandas/core/window/doc.py

+9
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ def create_section_header(header: str) -> str:
2929
"""
3030
).replace("\n", "", 1)
3131

32+
kwargs_numeric_only = dedent(
33+
"""
34+
numeric_only : bool, default False
35+
Include only float, int, boolean columns.
36+
37+
.. versionadded:: 1.5.0\n
38+
"""
39+
).replace("\n", "", 1)
40+
3241
args_compat = dedent(
3342
"""
3443
*args

pandas/core/window/ewm.py

+58-13
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
from pandas.util._decorators import doc
2727
from pandas.util._exceptions import find_stack_level
2828

29-
from pandas.core.dtypes.common import is_datetime64_ns_dtype
29+
from pandas.core.dtypes.common import (
30+
is_datetime64_ns_dtype,
31+
is_numeric_dtype,
32+
)
3033
from pandas.core.dtypes.missing import isna
3134

3235
import pandas.core.common as common # noqa: PDF018
@@ -45,6 +48,7 @@
4548
args_compat,
4649
create_section_header,
4750
kwargs_compat,
51+
kwargs_numeric_only,
4852
numba_notes,
4953
template_header,
5054
template_returns,
@@ -518,6 +522,7 @@ def aggregate(self, func, *args, **kwargs):
518522
@doc(
519523
template_header,
520524
create_section_header("Parameters"),
525+
kwargs_numeric_only,
521526
args_compat,
522527
window_agg_numba_parameters(),
523528
kwargs_compat,
@@ -531,7 +536,14 @@ def aggregate(self, func, *args, **kwargs):
531536
aggregation_description="(exponential weighted moment) mean",
532537
agg_method="mean",
533538
)
534-
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
539+
def mean(
540+
self,
541+
numeric_only: bool = False,
542+
*args,
543+
engine=None,
544+
engine_kwargs=None,
545+
**kwargs,
546+
):
535547
if maybe_use_numba(engine):
536548
if self.method == "single":
537549
func = generate_numba_ewm_func
@@ -545,7 +557,7 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
545557
deltas=tuple(self._deltas),
546558
normalize=True,
547559
)
548-
return self._apply(ewm_func)
560+
return self._apply(ewm_func, name="mean")
549561
elif engine in ("cython", None):
550562
if engine_kwargs is not None:
551563
raise ValueError("cython engine does not accept engine_kwargs")
@@ -560,13 +572,14 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
560572
deltas=deltas,
561573
normalize=True,
562574
)
563-
return self._apply(window_func)
575+
return self._apply(window_func, name="mean", numeric_only=numeric_only)
564576
else:
565577
raise ValueError("engine must be either 'numba' or 'cython'")
566578

567579
@doc(
568580
template_header,
569581
create_section_header("Parameters"),
582+
kwargs_numeric_only,
570583
args_compat,
571584
window_agg_numba_parameters(),
572585
kwargs_compat,
@@ -580,7 +593,14 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
580593
aggregation_description="(exponential weighted moment) sum",
581594
agg_method="sum",
582595
)
583-
def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
596+
def sum(
597+
self,
598+
numeric_only: bool = False,
599+
*args,
600+
engine=None,
601+
engine_kwargs=None,
602+
**kwargs,
603+
):
584604
if not self.adjust:
585605
raise NotImplementedError("sum is not implemented with adjust=False")
586606
if maybe_use_numba(engine):
@@ -596,7 +616,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
596616
deltas=tuple(self._deltas),
597617
normalize=False,
598618
)
599-
return self._apply(ewm_func)
619+
return self._apply(ewm_func, name="sum")
600620
elif engine in ("cython", None):
601621
if engine_kwargs is not None:
602622
raise ValueError("cython engine does not accept engine_kwargs")
@@ -611,7 +631,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
611631
deltas=deltas,
612632
normalize=False,
613633
)
614-
return self._apply(window_func)
634+
return self._apply(window_func, name="sum", numeric_only=numeric_only)
615635
else:
616636
raise ValueError("engine must be either 'numba' or 'cython'")
617637

@@ -624,6 +644,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
624644
Use a standard estimation bias correction.
625645
"""
626646
).replace("\n", "", 1),
647+
kwargs_numeric_only,
627648
args_compat,
628649
kwargs_compat,
629650
create_section_header("Returns"),
@@ -634,9 +655,18 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
634655
aggregation_description="(exponential weighted moment) standard deviation",
635656
agg_method="std",
636657
)
637-
def std(self, bias: bool = False, *args, **kwargs):
658+
def std(self, bias: bool = False, numeric_only: bool = False, *args, **kwargs):
638659
nv.validate_window_func("std", args, kwargs)
639-
return zsqrt(self.var(bias=bias, **kwargs))
660+
if (
661+
numeric_only
662+
and self._selected_obj.ndim == 1
663+
and not is_numeric_dtype(self._selected_obj.dtype)
664+
):
665+
# Raise directly so error message says std instead of var
666+
raise NotImplementedError(
667+
f"{type(self).__name__}.std does not implement numeric_only"
668+
)
669+
return zsqrt(self.var(bias=bias, numeric_only=numeric_only, **kwargs))
640670

641671
def vol(self, bias: bool = False, *args, **kwargs):
642672
warnings.warn(
@@ -658,6 +688,7 @@ def vol(self, bias: bool = False, *args, **kwargs):
658688
Use a standard estimation bias correction.
659689
"""
660690
).replace("\n", "", 1),
691+
kwargs_numeric_only,
661692
args_compat,
662693
kwargs_compat,
663694
create_section_header("Returns"),
@@ -668,7 +699,7 @@ def vol(self, bias: bool = False, *args, **kwargs):
668699
aggregation_description="(exponential weighted moment) variance",
669700
agg_method="var",
670701
)
671-
def var(self, bias: bool = False, *args, **kwargs):
702+
def var(self, bias: bool = False, numeric_only: bool = False, *args, **kwargs):
672703
nv.validate_window_func("var", args, kwargs)
673704
window_func = window_aggregations.ewmcov
674705
wfunc = partial(
@@ -682,7 +713,7 @@ def var(self, bias: bool = False, *args, **kwargs):
682713
def var_func(values, begin, end, min_periods):
683714
return wfunc(values, begin, end, min_periods, values)
684715

685-
return self._apply(var_func)
716+
return self._apply(var_func, name="var", numeric_only=numeric_only)
686717

687718
@doc(
688719
template_header,
@@ -703,6 +734,7 @@ def var_func(values, begin, end, min_periods):
703734
Use a standard estimation bias correction.
704735
"""
705736
).replace("\n", "", 1),
737+
kwargs_numeric_only,
706738
kwargs_compat,
707739
create_section_header("Returns"),
708740
template_returns,
@@ -717,10 +749,13 @@ def cov(
717749
other: DataFrame | Series | None = None,
718750
pairwise: bool | None = None,
719751
bias: bool = False,
752+
numeric_only: bool = False,
720753
**kwargs,
721754
):
722755
from pandas import Series
723756

757+
self._validate_numeric_only("cov", numeric_only)
758+
724759
def cov_func(x, y):
725760
x_array = self._prep_values(x)
726761
y_array = self._prep_values(y)
@@ -752,7 +787,9 @@ def cov_func(x, y):
752787
)
753788
return Series(result, index=x.index, name=x.name)
754789

755-
return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func)
790+
return self._apply_pairwise(
791+
self._selected_obj, other, pairwise, cov_func, numeric_only
792+
)
756793

757794
@doc(
758795
template_header,
@@ -771,6 +808,7 @@ def cov_func(x, y):
771808
observations will be used.
772809
"""
773810
).replace("\n", "", 1),
811+
kwargs_numeric_only,
774812
kwargs_compat,
775813
create_section_header("Returns"),
776814
template_returns,
@@ -784,10 +822,13 @@ def corr(
784822
self,
785823
other: DataFrame | Series | None = None,
786824
pairwise: bool | None = None,
825+
numeric_only: bool = False,
787826
**kwargs,
788827
):
789828
from pandas import Series
790829

830+
self._validate_numeric_only("corr", numeric_only)
831+
791832
def cov_func(x, y):
792833
x_array = self._prep_values(x)
793834
y_array = self._prep_values(y)
@@ -825,7 +866,9 @@ def _cov(X, Y):
825866
result = cov / zsqrt(x_var * y_var)
826867
return Series(result, index=x.index, name=x.name)
827868

828-
return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func)
869+
return self._apply_pairwise(
870+
self._selected_obj, other, pairwise, cov_func, numeric_only
871+
)
829872

830873

831874
class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
@@ -921,6 +964,7 @@ def corr(
921964
self,
922965
other: DataFrame | Series | None = None,
923966
pairwise: bool | None = None,
967+
numeric_only: bool = False,
924968
**kwargs,
925969
):
926970
return NotImplementedError
@@ -930,6 +974,7 @@ def cov(
930974
other: DataFrame | Series | None = None,
931975
pairwise: bool | None = None,
932976
bias: bool = False,
977+
numeric_only: bool = False,
933978
**kwargs,
934979
):
935980
return NotImplementedError

0 commit comments

Comments
 (0)