diff --git a/pandas-stubs/core/groupby/base.pyi b/pandas-stubs/core/groupby/base.pyi index f56b6a32..97faad2b 100644 --- a/pandas-stubs/core/groupby/base.pyi +++ b/pandas-stubs/core/groupby/base.pyi @@ -1,7 +1,56 @@ from collections.abc import Hashable import dataclasses +from typing import ( + Literal, + TypeAlias, +) @dataclasses.dataclass(order=True, frozen=True) class OutputKey: label: Hashable position: int + +ReductionKernelType: TypeAlias = Literal[ + "all", + "any", + "corrwith", + "count", + "first", + "idxmax", + "idxmin", + "last", + "max", + "mean", + "median", + "min", + "nunique", + "prod", + # as long as `quantile`'s signature accepts only + # a single quantile value, it's a reduction. + # GH#27526 might change that. + "quantile", + "sem", + "size", + "skew", + "std", + "sum", + "var", +] + +TransformationKernelType: TypeAlias = Literal[ + "bfill", + "cumcount", + "cummax", + "cummin", + "cumprod", + "cumsum", + "diff", + "ffill", + "fillna", + "ngroup", + "pct_change", + "rank", + "shift", +] + +TransformReductionListType: TypeAlias = ReductionKernelType | TransformationKernelType diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index f618a592..920c962f 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -7,6 +7,7 @@ from collections.abc import ( ) from typing import ( Any, + Concatenate, Generic, Literal, NamedTuple, @@ -18,11 +19,15 @@ from typing import ( from matplotlib.axes import Axes as PlotAxes import numpy as np from pandas.core.frame import DataFrame +from pandas.core.groupby.base import TransformReductionListType from pandas.core.groupby.groupby import ( GroupBy, GroupByPlot, ) -from pandas.core.series import Series +from pandas.core.series import ( + Series, + UnknownSeries, +) from typing_extensions import ( Self, TypeAlias, @@ -31,6 +36,7 @@ from typing_extensions import ( from pandas._libs.tslibs.timestamps import Timestamp from pandas._typing import ( S1, + S2, AggFuncTypeBase, AggFuncTypeFrame, ByT, @@ -40,6 +46,7 @@ from pandas._typing import ( Level, ListLike, NsmallestNlargestKeep, + P, Scalar, TakeIndexer, WindowingEngine, @@ -53,10 +60,30 @@ class NamedAgg(NamedTuple): aggfunc: AggScalar class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): + @overload + def aggregate( + self, + func: Callable[Concatenate[Series[S1], P], S2], + /, + *args, + engine: WindowingEngine = ..., + engine_kwargs: WindowingEngineKwargs = ..., + **kwargs, + ) -> Series[S2]: ... + @overload + def aggregate( + self, + func: Callable[[Series], S2], + *args, + engine: WindowingEngine = ..., + engine_kwargs: WindowingEngineKwargs = ..., + **kwargs, + ) -> Series[S2]: ... @overload def aggregate( self, func: list[AggFuncTypeBase], + /, *args, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., @@ -66,20 +93,34 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): def aggregate( self, func: AggFuncTypeBase | None = ..., + /, *args, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., **kwargs, - ) -> Series: ... + ) -> UnknownSeries: ... agg = aggregate + @overload def transform( self, - func: Callable | str, - *args, + func: Callable[Concatenate[Series[S1], P], Series[S2]], + /, + *args: Any, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., - **kwargs, - ) -> Series: ... + **kwargs: Any, + ) -> Series[S2]: ... + @overload + def transform( + self, + func: Callable, + *args: Any, + **kwargs: Any, + ) -> UnknownSeries: ... + @overload + def transform( + self, func: TransformReductionListType, *args, **kwargs + ) -> UnknownSeries: ... def filter( self, func: Callable | str, dropna: bool = ..., *args, **kwargs ) -> Series: ... @@ -206,13 +247,25 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): **kwargs, ) -> DataFrame: ... agg = aggregate + @overload def transform( self, - func: Callable | str, - *args, + func: Callable[Concatenate[DataFrame, P], DataFrame], + *args: Any, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., - **kwargs, + **kwargs: Any, + ) -> DataFrame: ... + @overload + def transform( + self, + func: Callable, + *args: Any, + **kwargs: Any, + ) -> DataFrame: ... + @overload + def transform( + self, func: TransformReductionListType, *args, **kwargs ) -> DataFrame: ... def filter( self, func: Callable, dropna: bool = ..., *args, **kwargs diff --git a/tests/test_series.py b/tests/test_series.py index 57051997..161aedf4 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1078,12 +1078,63 @@ def test_types_groupby_agg() -> None: r"The provided callable is currently using", upper="2.2.99", ): - check(assert_type(s.groupby(level=0).agg(sum), pd.Series), pd.Series) + + def sum_sr(s: pd.Series[int]) -> int: + # type of `sum` not well inferred by mypy + return s.sum() + + check( + assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"), + pd.Series, + np.integer, + ) check( assert_type(s.groupby(level=0).agg([min, sum]), pd.DataFrame), pd.DataFrame ) +def test_types_groupby_transform() -> None: + s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) + + def transform_func( + x: pd.Series[int], pos_arg: bool, kw_arg: str + ) -> pd.Series[float]: + return x / (2.0 if pos_arg else 1.0) + + check( + assert_type( + s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"), + "pd.Series[float]", + ), + pd.Series, + float, + ) + check( + assert_type( + s.groupby(lambda x: x).transform( + transform_func, True, engine="cython", kw_arg="foo" + ), + "pd.Series[float]", + ), + pd.Series, + float, + ) + check( + assert_type( + s.groupby(lambda x: x).transform("mean"), + "pd.Series", + ), + pd.Series, + ) + check( + assert_type( + s.groupby(lambda x: x).transform("first"), + "pd.Series", + ), + pd.Series, + ) + + def test_types_groupby_aggregate() -> None: s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) check(assert_type(s.groupby(level=0).aggregate("sum"), pd.Series), pd.Series) @@ -1091,12 +1142,47 @@ def test_types_groupby_aggregate() -> None: assert_type(s.groupby(level=0).aggregate(["min", "sum"]), pd.DataFrame), pd.DataFrame, ) + + def func(s: pd.Series[int]) -> float: + return s.astype(float).min() + + s = pd.Series([1, 2, 3, 4]) + check( + assert_type(s.groupby([1, 1, 2, 2]).agg(func), "pd.Series[float]"), + pd.Series, + np.floating, + ) + check( + assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"), + pd.Series, + np.floating, + ) + check( + assert_type( + s.groupby(level=0).aggregate(func, engine="cython"), "pd.Series[float]" + ), + pd.Series, + np.floating, + ) + + # test below fails with mypy but pyright correctly sees it as pd.Series[float] + # check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), "pd.Series[float]"), pd.Series, float) + with pytest_warns_bounded( FutureWarning, r"The provided callable is currently using", upper="2.2.99", ): - check(assert_type(s.groupby(level=0).aggregate(sum), pd.Series), pd.Series) + + def sum_sr(s: pd.Series[int]) -> int: + # type of `sum` not well inferred by mypy + return s.sum() + + check( + assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"), + pd.Series, + np.integer, + ) check( assert_type(s.groupby(level=0).aggregate([min, sum]), pd.DataFrame), pd.DataFrame,