diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi
index 3ad916f76..cb6fa46bb 100644
--- a/pandas-stubs/_typing.pyi
+++ b/pandas-stubs/_typing.pyi
@@ -24,7 +24,10 @@ from pandas.core.generic import NDFrame
 from pandas.core.groupby.grouper import Grouper
 from pandas.core.indexes.base import Index
 from pandas.core.series import Series
-from typing_extensions import TypeAlias
+from typing_extensions import (
+    ParamSpec,
+    TypeAlias,
+)
 
 from pandas._libs.interval import Interval
 from pandas._libs.tslibs import (
@@ -123,6 +126,7 @@ JSONSerializable: TypeAlias = Union[PythonScalar, list, dict]
 Axes: TypeAlias = Union[AnyArrayLike, list, dict, range, tuple]
 Renamer: TypeAlias = Union[Mapping[Any, Label], Callable[[Any], Label]]
 T = TypeVar("T")
+P = ParamSpec("P")
 FuncType: TypeAlias = Callable[..., Any]
 F = TypeVar("F", bound=FuncType)
 HashableT = TypeVar("HashableT", bound=Hashable)
@@ -202,6 +206,27 @@ S1 = TypeVar(
     Interval[Timestamp],
     Interval[Timedelta],
 )
+S2 = TypeVar(
+    "S2",
+    str,
+    bytes,
+    datetime.date,
+    datetime.datetime,
+    datetime.time,
+    datetime.timedelta,
+    bool,
+    int,
+    float,
+    complex,
+    Timestamp,
+    Timedelta,
+    np.datetime64,
+    Period,
+    Interval[int],
+    Interval[float],
+    Interval[Timestamp],
+    Interval[Timedelta],
+)
 T1 = TypeVar(
     "T1", str, int, np.int64, np.uint64, np.float64, float, np.dtype[np.generic]
 )
@@ -285,6 +310,44 @@ GroupByObjectNonScalar: TypeAlias = Union[
     list[Grouper],
 ]
 GroupByObject: TypeAlias = Union[Scalar, GroupByObjectNonScalar]
+GroupByFuncStrs: TypeAlias = Literal[
+    # Reduction/aggregation functions
+    "all",
+    "any",
+    "corrwith",
+    "count",
+    "first",
+    "idxmax",
+    "idxmin",
+    "last",
+    "max",
+    "mean",
+    "median",
+    "min",
+    "nunique",
+    "prod",
+    "quantile",
+    "sem",
+    "size",
+    "skew",
+    "std",
+    "sum",
+    "var",
+    # Transformation functions
+    "bfill",
+    "cumcount",
+    "cummax",
+    "cummin",
+    "cumprod",
+    "cumsum",
+    "diff",
+    "ffill",
+    "fillna",
+    "ngroup",
+    "pct_change",
+    "rank",
+    "shift",
+]
 
StataDateFormat: TypeAlias = Literal[ "tc", diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index c4a56f8b0..47b8ea3e5 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -22,15 +22,21 @@ from pandas.core.groupby.groupby import ( # , get_groupby as get_groupby ) from pandas.core.groupby.grouper import Grouper from pandas.core.series import Series -from typing_extensions import TypeAlias +from typing_extensions import ( + Concatenate, + TypeAlias, +) from pandas._typing import ( S1, + S2, AggFuncTypeBase, AggFuncTypeFrame, AxisType, + GroupByFuncStrs, Level, ListLike, + P, Scalar, ) @@ -61,7 +67,17 @@ class SeriesGroupBy(GroupBy, Generic[S1]): def agg(self, func: list[AggFuncTypeBase], *args, **kwargs) -> DataFrame: ... @overload def agg(self, func: AggFuncTypeBase, *args, **kwargs) -> Series: ... - def transform(self, func: Callable | str, *args, **kwargs) -> Series: ... + @overload + def transform( + self, + func: Callable[Concatenate[Series[S1], P], Series[S2]], + *args: P.args, + engine: Literal["cython", "numba", None] = ..., + engine_kwargs: dict[str, Any] | None = ..., + **kwargs: P.kwargs, + ) -> Series[S2]: ... + @overload + def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> Series: ... def filter(self, func, dropna: bool = ..., *args, **kwargs): ... def nunique(self, dropna: bool = ...) -> Series: ... def describe(self, **kwargs) -> DataFrame: ... 
diff --git a/tests/test_series.py b/tests/test_series.py
index 12beb716d..c5d06f6d8 100644
--- a/tests/test_series.py
+++ b/tests/test_series.py
@@ -517,6 +517,24 @@ def test_types_groupby_methods() -> None:
     check(assert_type(s.groupby(level=0).unique(), pd.Series), pd.Series)
 
 
+def test_types_groupby_transform() -> None:
+    s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
+
+    def transform_func(
+        x: pd.Series[int], pos_arg: bool, kw_arg: str
+    ) -> pd.Series[float]:
+        return x / (2.0 if pos_arg else 1.0)
+
+    check(
+        assert_type(
+            s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"),
+            "pd.Series[float]",
+        ),
+        pd.Series,
+        float,
+    )
+
+
 def test_types_groupby_agg() -> None:
     s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
     check(assert_type(s.groupby(level=0).agg("sum"), pd.Series), pd.Series)
@@ -641,9 +659,9 @@ def test_types_aggregate() -> None:
 
 
 def test_types_transform() -> None:
-    s = pd.Series([1, 2, 3], index=["col1", "col2", "col3"])
-    check(assert_type(s.transform("abs"), pd.Series), pd.Series)
-    check(assert_type(s.transform(abs), pd.Series), pd.Series)
+    s: pd.Series[int] = pd.Series([1, 2, 3], index=["col1", "col2", "col3"])
+    check(assert_type(s.transform("abs"), "pd.Series[int]"), pd.Series, int)
+    check(assert_type(s.transform(abs), "pd.Series[int]"), pd.Series, int)
     check(assert_type(s.transform(["abs", "sqrt"]), pd.DataFrame), pd.DataFrame)
     check(assert_type(s.transform([abs, np.sqrt]), pd.DataFrame), pd.DataFrame)
     check(