Skip to content

Commit 315683f

Browse files
bashtageKevin Sheppard
and
Kevin Sheppard
authored
ENH: Improve Roling, Expanding and EWM (#297)
* ENH: IMprove Roling, Expanding and EWM * ENH: Add generic * TST: Add tests for rolling * TST: Add tests and fix types * Improve typing * TYP Further improvements and refactoring Co-authored-by: Kevin Sheppard <[email protected]>
1 parent a3cabb3 commit 315683f

File tree

11 files changed

+735
-204
lines changed

11 files changed

+735
-204
lines changed

pandas-stubs/_typing.pyi

+21-4
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,17 @@ FuncType = Callable[..., Any]
9898
F = TypeVar("F", bound=FuncType)
9999
HashableT = TypeVar("HashableT", bound=Hashable)
100100

101-
AggFuncTypeBase = Union[Callable, str]
102-
AggFuncTypeDict = dict[Hashable, Union[AggFuncTypeBase, list[AggFuncTypeBase]]]
103-
AggFuncType = Union[
101+
AggFuncTypeBase = Union[Callable, str, np.ufunc]
102+
AggFuncTypeDictSeries = dict[Hashable, AggFuncTypeBase]
103+
AggFuncTypeDictFrame = dict[Hashable, Union[AggFuncTypeBase, list[AggFuncTypeBase]]]
104+
AggFuncTypeSeriesToFrame = Union[
105+
list[AggFuncTypeBase],
106+
AggFuncTypeDictSeries,
107+
]
108+
AggFuncTypeFrame = Union[
104109
AggFuncTypeBase,
105110
list[AggFuncTypeBase],
106-
AggFuncTypeDict,
111+
AggFuncTypeDictFrame,
107112
]
108113

109114
num = complex
@@ -264,6 +269,18 @@ FileWriteMode = Literal[
264269
]
265270
ColspaceArgType = str | int | Sequence[int | str] | Mapping[Hashable, str | int]
266271

272+
# Windowing rank methods
273+
WindowingRankType = Literal["average", "min", "max"]
274+
WindowingEngine = Union[Literal["cython", "numba"], None]
275+
276+
class _WindowingNumbaKwargs(TypedDict, total=False):
277+
nopython: bool
278+
nogil: bool
279+
parallel: bool
280+
281+
WindowingEngineKwargs = Union[_WindowingNumbaKwargs, None]
282+
QuantileInterpolation = Literal["linear", "lower", "higher", "midpoint", "nearest"]
283+
267284
class StyleExportDict(TypedDict, total=False):
268285
apply: Any
269286
table_attributes: Any

pandas-stubs/core/frame.pyi

+26-17
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ from pandas.core.indexing import (
3434
)
3535
from pandas.core.resample import Resampler
3636
from pandas.core.series import Series
37+
from pandas.core.window import (
38+
Expanding,
39+
ExponentialMovingWindow,
40+
)
3741
from pandas.core.window.rolling import (
3842
Rolling,
3943
Window,
@@ -43,9 +47,9 @@ import xarray as xr
4347
from pandas._libs.missing import NAType
4448
from pandas._typing import (
4549
S1,
46-
AggFuncType,
4750
AggFuncTypeBase,
48-
AggFuncTypeDict,
51+
AggFuncTypeDictFrame,
52+
AggFuncTypeFrame,
4953
AnyArrayLike,
5054
ArrayLike,
5155
Axes,
@@ -65,6 +69,7 @@ from pandas._typing import (
6569
IndexingInt,
6670
IndexLabel,
6771
IndexType,
72+
IntervalClosedType,
6873
JsonFrameOrient,
6974
Label,
7075
Level,
@@ -74,6 +79,7 @@ from pandas._typing import (
7479
MergeHow,
7580
NaPosition,
7681
ParquetEngine,
82+
QuantileInterpolation,
7783
ReadBuffer,
7884
Renamer,
7985
ReplaceMethod,
@@ -1052,7 +1058,7 @@ class DataFrame(NDFrame, OpsMixin):
10521058
@overload
10531059
def agg(
10541060
self,
1055-
func: list[AggFuncTypeBase] | AggFuncTypeDict = ...,
1061+
func: list[AggFuncTypeBase] | AggFuncTypeDictFrame = ...,
10561062
axis: AxisType = ...,
10571063
**kwargs,
10581064
) -> DataFrame: ...
@@ -1063,13 +1069,13 @@ class DataFrame(NDFrame, OpsMixin):
10631069
@overload
10641070
def aggregate(
10651071
self,
1066-
func: list[AggFuncTypeBase] | AggFuncTypeDict,
1072+
func: list[AggFuncTypeBase] | AggFuncTypeDictFrame,
10671073
axis: AxisType = ...,
10681074
**kwargs,
10691075
) -> DataFrame: ...
10701076
def transform(
10711077
self,
1072-
func: AggFuncType,
1078+
func: AggFuncTypeFrame,
10731079
axis: AxisType = ...,
10741080
*args,
10751081
**kwargs,
@@ -1165,17 +1171,15 @@ class DataFrame(NDFrame, OpsMixin):
11651171
q: float = ...,
11661172
axis: AxisType = ...,
11671173
numeric_only: _bool = ...,
1168-
interpolation: _str
1169-
| Literal["linear", "lower", "higher", "midpoint", "nearest"] = ...,
1174+
interpolation: QuantileInterpolation = ...,
11701175
) -> Series: ...
11711176
@overload
11721177
def quantile(
11731178
self,
11741179
q: list[float] | np.ndarray,
11751180
axis: AxisType = ...,
11761181
numeric_only: _bool = ...,
1177-
interpolation: _str
1178-
| Literal["linear", "lower", "higher", "midpoint", "nearest"] = ...,
1182+
interpolation: QuantileInterpolation = ...,
11791183
) -> DataFrame: ...
11801184
def to_timestamp(
11811185
self,
@@ -1412,8 +1416,13 @@ class DataFrame(NDFrame, OpsMixin):
14121416
adjust: _bool = ...,
14131417
ignore_na: _bool = ...,
14141418
axis: AxisType = ...,
1415-
) -> DataFrame: ...
1416-
def expanding(self, min_periods: int = ..., axis: AxisType = ...): ... # for now
1419+
) -> ExponentialMovingWindow[DataFrame]: ...
1420+
def expanding(
1421+
self,
1422+
min_periods: int = ...,
1423+
axis: AxisType = ...,
1424+
method: Literal["single", "table"] = ...,
1425+
) -> Expanding[DataFrame]: ...
14171426
@overload
14181427
def ffill(
14191428
self,
@@ -1767,21 +1776,21 @@ class DataFrame(NDFrame, OpsMixin):
17671776
center: _bool = ...,
17681777
*,
17691778
win_type: _str,
1770-
on: _str | None = ...,
1779+
on: Hashable | None = ...,
17711780
axis: AxisType = ...,
1772-
closed: _str | None = ...,
1773-
) -> Window: ...
1781+
closed: IntervalClosedType | None = ...,
1782+
) -> Window[DataFrame]: ...
17741783
@overload
17751784
def rolling(
17761785
self,
17771786
window,
17781787
min_periods: int | None = ...,
17791788
center: _bool = ...,
17801789
*,
1781-
on: _str | None = ...,
1790+
on: Hashable | None = ...,
17821791
axis: AxisType = ...,
1783-
closed: _str | None = ...,
1784-
) -> Rolling: ...
1792+
closed: IntervalClosedType | None = ...,
1793+
) -> Rolling[DataFrame]: ...
17851794
def rpow(
17861795
self,
17871796
other,

pandas-stubs/core/groupby/generic.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ from pandas.core.series import Series
2626

2727
from pandas._typing import (
2828
S1,
29-
AggFuncType,
3029
AggFuncTypeBase,
30+
AggFuncTypeFrame,
3131
AxisType,
3232
Level,
3333
ListLike,
@@ -154,8 +154,8 @@ class DataFrameGroupBy(GroupBy):
154154
def apply( # pyright: ignore[reportOverlappingOverload]
155155
self, func: Callable[[Iterable], float], *args, **kwargs
156156
) -> DataFrame: ...
157-
def aggregate(self, arg: AggFuncType = ..., *args, **kwargs) -> DataFrame: ...
158-
def agg(self, arg: AggFuncType = ..., *args, **kwargs) -> DataFrame: ...
157+
def aggregate(self, arg: AggFuncTypeFrame = ..., *args, **kwargs) -> DataFrame: ...
158+
agg = aggregate
159159
def transform(self, func, *args, **kwargs): ...
160160
def filter(
161161
self, func: Callable, dropna: bool = ..., *args, **kwargs

pandas-stubs/core/series.pyi

+21-29
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ from pandas.core.indexing import (
4949
)
5050
from pandas.core.resample import Resampler
5151
from pandas.core.strings import StringMethods
52-
from pandas.core.window import ExponentialMovingWindow
52+
from pandas.core.window import (
53+
Expanding,
54+
ExponentialMovingWindow,
55+
Rolling,
56+
)
5357
from pandas.core.window.rolling import (
5458
Rolling,
5559
Window,
@@ -60,7 +64,8 @@ from pandas._libs.missing import NAType
6064
from pandas._typing import (
6165
S1,
6266
AggFuncTypeBase,
63-
AggFuncTypeDict,
67+
AggFuncTypeDictFrame,
68+
AggFuncTypeSeriesToFrame,
6469
ArrayLike,
6570
Axes,
6671
Axis,
@@ -80,6 +85,7 @@ from pandas._typing import (
8085
MaskType,
8186
MergeHow,
8287
NaPosition,
88+
QuantileInterpolation,
8389
Renamer,
8490
ReplaceMethod,
8591
Scalar,
@@ -454,15 +460,13 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
454460
def quantile(
455461
self,
456462
q: float = ...,
457-
interpolation: _str
458-
| Literal["linear", "lower", "higher", "midpoint", "nearest"] = ...,
463+
interpolation: QuantileInterpolation = ...,
459464
) -> float: ...
460465
@overload
461466
def quantile(
462467
self,
463468
q: _ListLike,
464-
interpolation: _str
465-
| Literal["linear", "lower", "higher", "midpoint", "nearest"] = ...,
469+
interpolation: QuantileInterpolation = ...,
466470
) -> Series[S1]: ...
467471
def corr(
468472
self,
@@ -629,27 +633,12 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
629633
@overload
630634
def aggregate(
631635
self,
632-
func: list[AggFuncTypeBase] | dict[Hashable, AggFuncTypeBase] = ...,
633-
axis: SeriesAxisType = ...,
634-
*args,
635-
**kwargs,
636-
) -> Series[S1]: ...
637-
@overload
638-
def agg(
639-
self,
640-
func: AggFuncTypeBase,
641-
axis: SeriesAxisType = ...,
642-
*args,
643-
**kwargs,
644-
) -> S1: ...
645-
@overload
646-
def agg(
647-
self,
648-
func: list[AggFuncTypeBase] | dict[Hashable, AggFuncTypeBase] = ...,
636+
func: AggFuncTypeSeriesToFrame = ...,
649637
axis: SeriesAxisType = ...,
650638
*args,
651639
**kwargs,
652640
) -> Series[S1]: ...
641+
agg = aggregate
653642
@overload
654643
def transform(
655644
self,
@@ -661,7 +650,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
661650
@overload
662651
def transform(
663652
self,
664-
func: list[AggFuncTypeBase] | AggFuncTypeDict,
653+
func: list[AggFuncTypeBase] | AggFuncTypeDictFrame,
665654
axis: SeriesAxisType = ...,
666655
*args,
667656
**kwargs,
@@ -1322,10 +1311,13 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
13221311
adjust: _bool = ...,
13231312
ignore_na: _bool = ...,
13241313
axis: SeriesAxisType = ...,
1325-
) -> ExponentialMovingWindow: ...
1314+
) -> ExponentialMovingWindow[Series]: ...
13261315
def expanding(
1327-
self, min_periods: int = ..., axis: SeriesAxisType = ...
1328-
) -> DataFrame: ...
1316+
self,
1317+
min_periods: int = ...,
1318+
axis: SeriesAxisType = ...,
1319+
method: Literal["single", "table"] = ...,
1320+
) -> Expanding[Series]: ...
13291321
def floordiv(
13301322
self,
13311323
other: num | _ListLike | Series[S1],
@@ -1528,7 +1520,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
15281520
on: _str | None = ...,
15291521
axis: SeriesAxisType = ...,
15301522
closed: _str | None = ...,
1531-
) -> Window: ...
1523+
) -> Window[Series]: ...
15321524
@overload
15331525
def rolling(
15341526
self,
@@ -1539,7 +1531,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
15391531
on: _str | None = ...,
15401532
axis: SeriesAxisType = ...,
15411533
closed: _str | None = ...,
1542-
) -> Rolling: ...
1534+
) -> Rolling[Series]: ...
15431535
def rpow(
15441536
self,
15451537
other: Series[S1] | Scalar,

0 commit comments

Comments
 (0)