Skip to content

Commit ba7f884

Browse files
committed
ENH: Add numeric_only to window ops
1 parent 3bf2cb1 commit ba7f884

File tree

8 files changed

+406
-85
lines changed

8 files changed

+406
-85
lines changed

pandas/core/window/doc.py

+8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ def create_section_header(header: str) -> str:
2929
"""
3030
).replace("\n", "", 1)
3131

32+
kwargs_numeric_only = dedent(
33+
"""
34+
numeric_only : bool, default False
35+
Include only float, int, boolean columns. If None, will attempt to use
36+
everything, then use only numeric data.
37+
"""
38+
)
39+
3240
args_compat = dedent(
3341
"""
3442
*args

pandas/core/window/ewm.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
args_compat,
4646
create_section_header,
4747
kwargs_compat,
48+
kwargs_numeric_only,
4849
numba_notes,
4950
template_header,
5051
template_returns,
@@ -518,6 +519,7 @@ def aggregate(self, func, *args, **kwargs):
518519
@doc(
519520
template_header,
520521
create_section_header("Parameters"),
522+
kwargs_numeric_only,
521523
args_compat,
522524
window_agg_numba_parameters(),
523525
kwargs_compat,
@@ -531,7 +533,14 @@ def aggregate(self, func, *args, **kwargs):
531533
aggregation_description="(exponential weighted moment) mean",
532534
agg_method="mean",
533535
)
534-
def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
536+
def mean(
537+
self,
538+
numeric_only: bool = False,
539+
*args,
540+
engine=None,
541+
engine_kwargs=None,
542+
**kwargs,
543+
):
535544
if maybe_use_numba(engine):
536545
if self.method == "single":
537546
func = generate_numba_ewm_func
@@ -560,13 +569,14 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
560569
deltas=deltas,
561570
normalize=True,
562571
)
563-
return self._apply(window_func)
572+
return self._apply(window_func, numeric_only=numeric_only)
564573
else:
565574
raise ValueError("engine must be either 'numba' or 'cython'")
566575

567576
@doc(
568577
template_header,
569578
create_section_header("Parameters"),
579+
kwargs_numeric_only,
570580
args_compat,
571581
window_agg_numba_parameters(),
572582
kwargs_compat,
@@ -580,7 +590,14 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
580590
aggregation_description="(exponential weighted moment) sum",
581591
agg_method="sum",
582592
)
583-
def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
593+
def sum(
594+
self,
595+
numeric_only: bool = False,
596+
*args,
597+
engine=None,
598+
engine_kwargs=None,
599+
**kwargs,
600+
):
584601
if not self.adjust:
585602
raise NotImplementedError("sum is not implemented with adjust=False")
586603
if maybe_use_numba(engine):
@@ -611,7 +628,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
611628
deltas=deltas,
612629
normalize=False,
613630
)
614-
return self._apply(window_func)
631+
return self._apply(window_func, numeric_only=numeric_only)
615632
else:
616633
raise ValueError("engine must be either 'numba' or 'cython'")
617634

@@ -624,6 +641,7 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
624641
Use a standard estimation bias correction.
625642
"""
626643
).replace("\n", "", 1),
644+
kwargs_numeric_only,
627645
args_compat,
628646
kwargs_compat,
629647
create_section_header("Returns"),
@@ -634,9 +652,9 @@ def sum(self, *args, engine=None, engine_kwargs=None, **kwargs):
634652
aggregation_description="(exponential weighted moment) standard deviation",
635653
agg_method="std",
636654
)
637-
def std(self, bias: bool = False, *args, **kwargs):
655+
def std(self, bias: bool = False, numeric_only: bool = False, *args, **kwargs):
638656
nv.validate_window_func("std", args, kwargs)
639-
return zsqrt(self.var(bias=bias, **kwargs))
657+
return zsqrt(self.var(bias=bias, numeric_only=numeric_only, **kwargs))
640658

641659
def vol(self, bias: bool = False, *args, **kwargs):
642660
warnings.warn(
@@ -658,6 +676,7 @@ def vol(self, bias: bool = False, *args, **kwargs):
658676
Use a standard estimation bias correction.
659677
"""
660678
).replace("\n", "", 1),
679+
kwargs_numeric_only,
661680
args_compat,
662681
kwargs_compat,
663682
create_section_header("Returns"),
@@ -668,7 +687,7 @@ def vol(self, bias: bool = False, *args, **kwargs):
668687
aggregation_description="(exponential weighted moment) variance",
669688
agg_method="var",
670689
)
671-
def var(self, bias: bool = False, *args, **kwargs):
690+
def var(self, bias: bool = False, numeric_only: bool = False, *args, **kwargs):
672691
nv.validate_window_func("var", args, kwargs)
673692
window_func = window_aggregations.ewmcov
674693
wfunc = partial(
@@ -682,7 +701,7 @@ def var(self, bias: bool = False, *args, **kwargs):
682701
def var_func(values, begin, end, min_periods):
683702
return wfunc(values, begin, end, min_periods, values)
684703

685-
return self._apply(var_func)
704+
return self._apply(var_func, numeric_only=numeric_only)
686705

687706
@doc(
688707
template_header,
@@ -703,6 +722,7 @@ def var_func(values, begin, end, min_periods):
703722
Use a standard estimation bias correction.
704723
"""
705724
).replace("\n", "", 1),
725+
kwargs_numeric_only,
706726
kwargs_compat,
707727
create_section_header("Returns"),
708728
template_returns,
@@ -717,6 +737,7 @@ def cov(
717737
other: DataFrame | Series | None = None,
718738
pairwise: bool | None = None,
719739
bias: bool = False,
740+
numeric_only: bool = False,
720741
**kwargs,
721742
):
722743
from pandas import Series
@@ -752,7 +773,9 @@ def cov_func(x, y):
752773
)
753774
return Series(result, index=x.index, name=x.name)
754775

755-
return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func)
776+
return self._apply_pairwise(
777+
self._selected_obj, other, pairwise, cov_func, numeric_only
778+
)
756779

757780
@doc(
758781
template_header,
@@ -771,6 +794,7 @@ def cov_func(x, y):
771794
observations will be used.
772795
"""
773796
).replace("\n", "", 1),
797+
kwargs_numeric_only,
774798
kwargs_compat,
775799
create_section_header("Returns"),
776800
template_returns,
@@ -784,6 +808,7 @@ def corr(
784808
self,
785809
other: DataFrame | Series | None = None,
786810
pairwise: bool | None = None,
811+
numeric_only: bool = False,
787812
**kwargs,
788813
):
789814
from pandas import Series
@@ -825,7 +850,9 @@ def _cov(X, Y):
825850
result = cov / zsqrt(x_var * y_var)
826851
return Series(result, index=x.index, name=x.name)
827852

828-
return self._apply_pairwise(self._selected_obj, other, pairwise, cov_func)
853+
return self._apply_pairwise(
854+
self._selected_obj, other, pairwise, cov_func, numeric_only
855+
)
829856

830857

831858
class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
@@ -921,6 +948,7 @@ def corr(
921948
self,
922949
other: DataFrame | Series | None = None,
923950
pairwise: bool | None = None,
951+
numeric_only: bool = False,
924952
**kwargs,
925953
):
926954
return NotImplementedError
@@ -930,6 +958,7 @@ def cov(
930958
other: DataFrame | Series | None = None,
931959
pairwise: bool | None = None,
932960
bias: bool = False,
961+
numeric_only: bool = False,
933962
**kwargs,
934963
):
935964
return NotImplementedError

0 commit comments

Comments
 (0)