From 020f93d8bc426aebe8fdc6f2a07e19ff289a6449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Fri, 6 Jun 2025 17:15:50 -0400 Subject: [PATCH 1/8] GH456 First attempt GroupBy.transform improved typing --- pandas-stubs/_typing.pyi | 38 +++++++++++++++++++++++++++ pandas-stubs/core/groupby/generic.pyi | 36 ++++++++++++++++++++----- tests/test_series.py | 18 +++++++++++++ 3 files changed, 86 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index c404fc74..dfae7998 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -925,6 +925,44 @@ GroupByObjectNonScalar: TypeAlias = ( | list[Grouper] ) GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar | Series +GroupByFuncStrs: TypeAlias = Literal[ + # Reduction/aggregation functions + "all", + "any", + "corrwith", + "count", + "first", + "idxmax", + "idxmin", + "last", + "max", + "mean", + "median", + "min", + "nunique", + "prod", + "quantile", + "sem", + "size", + "skew", + "std", + "sum", + "var", + # Transformation functions + "bfill", + "cumcount", + "cummax", + "cummin", + "cumprod", + "cumsum", + "diff", + "ffill", + "fillna", + "ngroup", + "pct_change", + "rank", + "shift", +] StataDateFormat: TypeAlias = Literal[ "tc", diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index f618a592..1b6ae01d 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -7,6 +7,7 @@ from collections.abc import ( ) from typing import ( Any, + Concatenate, Generic, Literal, NamedTuple, @@ -31,15 +32,18 @@ from typing_extensions import ( from pandas._libs.tslibs.timestamps import Timestamp from pandas._typing import ( S1, + S2, AggFuncTypeBase, AggFuncTypeFrame, ByT, CorrelationMethod, Dtype, + GroupByFuncStrs, IndexLabel, Level, ListLike, NsmallestNlargestKeep, + P, Scalar, TakeIndexer, WindowingEngine, @@ -72,14 +76,24 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): **kwargs, ) -> Series: ... agg = aggregate + @overload def transform( self, - func: Callable | str, - *args, + func: Callable[Concatenate[Series[S1], P], Series[S2]], + *args: Any, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., - **kwargs, + **kwargs: Any, + ) -> Series[S2]: ... + @overload + def transform( + self, + func: Callable, + *args: Any, + **kwargs: Any, ) -> Series: ... + @overload + def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> Series: ... def filter( self, func: Callable | str, dropna: bool = ..., *args, **kwargs ) -> Series: ... @@ -206,14 +220,24 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): **kwargs, ) -> DataFrame: ... agg = aggregate + @overload def transform( self, - func: Callable | str, - *args, + func: Callable[Concatenate[DataFrame, P], DataFrame], + *args: Any, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., - **kwargs, + **kwargs: Any, ) -> DataFrame: ... + @overload + def transform( + self, + func: Callable, + *args: Any, + **kwargs: Any, + ) -> DataFrame: ... + @overload + def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> DataFrame: ... def filter( self, func: Callable, dropna: bool = ..., *args, **kwargs ) -> DataFrame: ... diff --git a/tests/test_series.py b/tests/test_series.py index 57051997..10046d61 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1084,6 +1084,24 @@ def test_types_groupby_agg() -> None: ) +def test_types_groupby_transform() -> None: + s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) + + def transform_func( + x: pd.Series[int], pos_arg: bool, kw_arg: str + ) -> pd.Series[float]: + return x / (2.0 if pos_arg else 1.0) + + check( + assert_type( + s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"), + "pd.Series[float]", + ), + pd.Series, + float, + ) + + def test_types_groupby_aggregate() -> None: s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) check(assert_type(s.groupby(level=0).aggregate("sum"), pd.Series), pd.Series) From 106a6f529046d3a133b7eb7077a3c22ec80672ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Fri, 6 Jun 2025 19:23:59 -0400 Subject: [PATCH 2/8] GH456 Attempt GroupBy.aggregate improved typing --- pandas-stubs/core/groupby/generic.pyi | 24 ++++++++++--- tests/test_series.py | 51 +++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 1b6ae01d..793cc01a 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -23,7 +23,10 @@ from pandas.core.groupby.groupby import ( GroupBy, GroupByPlot, ) -from pandas.core.series import Series +from pandas.core.series import ( + Series, + UnknownSeries, +) from typing_extensions import ( Self, TypeAlias, @@ -57,10 +60,21 @@ class NamedAgg(NamedTuple): aggfunc: AggScalar class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): + @overload + def aggregate( + self, + func: Callable[Concatenate[Series[S1], P], S2], + /, + *args, + engine: WindowingEngine = ..., + engine_kwargs: WindowingEngineKwargs = ..., + **kwargs, + ) -> Series[S2]: ... @overload def aggregate( self, func: list[AggFuncTypeBase], + /, *args, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., @@ -70,16 +84,18 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): def aggregate( self, func: AggFuncTypeBase | None = ..., + /, *args, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., **kwargs, - ) -> Series: ... + ) -> UnknownSeries: ... agg = aggregate @overload def transform( self, func: Callable[Concatenate[Series[S1], P], Series[S2]], + /, *args: Any, engine: WindowingEngine = ..., engine_kwargs: WindowingEngineKwargs = ..., @@ -91,9 +107,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): func: Callable, *args: Any, **kwargs: Any, - ) -> Series: ... + ) -> UnknownSeries: ... @overload - def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> Series: ... + def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> UnknownSeries: ... def filter( self, func: Callable | str, dropna: bool = ..., *args, **kwargs ) -> Series: ... diff --git a/tests/test_series.py b/tests/test_series.py index 10046d61..66219342 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1078,7 +1078,16 @@ def test_types_groupby_agg() -> None: r"The provided callable is currently using", upper="2.2.99", ): - check(assert_type(s.groupby(level=0).agg(sum), pd.Series), pd.Series) + + def sum_sr(s: pd.Series[int]) -> int: + # type of `sum` not well inferred by mypy + return sum(s) + + check( + assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"), + pd.Series, + np.integer, + ) check( assert_type(s.groupby(level=0).agg([min, sum]), pd.DataFrame), pd.DataFrame ) @@ -1100,6 +1109,16 @@ def transform_func( pd.Series, float, ) + check( + assert_type( + s.groupby(lambda x: x).transform( + transform_func, True, engine="cython", kw_arg="foo" + ), + "pd.Series[float]", + ), + pd.Series, + float, + ) def test_types_groupby_aggregate() -> None: @@ -1109,12 +1128,40 @@ def test_types_groupby_aggregate() -> None: assert_type(s.groupby(level=0).aggregate(["min", "sum"]), pd.DataFrame), pd.DataFrame, ) + + def func(s: pd.Series[int]) -> float: + return s.astype(float).min() + + s = pd.Series([1, 2, 3, 4]) + s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min()) + check( + assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"), + pd.Series, + np.floating, + ) + check( + assert_type( + s.groupby(level=0).aggregate(func, engine="cython"), "pd.Series[float]" + ), + pd.Series, + np.floating, + ) + with pytest_warns_bounded( FutureWarning, r"The provided callable is currently using", upper="2.2.99", ): - check(assert_type(s.groupby(level=0).aggregate(sum), pd.Series), pd.Series) + + def sum_sr(s: pd.Series[int]) -> int: + # type of `sum` not well inferred by mypy + return sum(s) + + check( + assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"), + pd.Series, + np.integer, + ) check( assert_type(s.groupby(level=0).aggregate([min, sum]), pd.DataFrame), pd.DataFrame, From 3bba10157431efa180ebe1288e1a72a5a82dc103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Fri, 6 Jun 2025 19:38:10 -0400 Subject: [PATCH 3/8] GH456 Attempt GroupBy.aggregate improved typing --- pandas-stubs/_typing.pyi | 38 --------------------- pandas-stubs/core/groupby/base.pyi | 49 +++++++++++++++++++++++++++ pandas-stubs/core/groupby/generic.pyi | 10 ++++-- 3 files changed, 56 insertions(+), 41 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index dfae7998..c404fc74 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -925,44 +925,6 @@ GroupByObjectNonScalar: TypeAlias = ( | list[Grouper] ) GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar | Series -GroupByFuncStrs: TypeAlias = Literal[ - # Reduction/aggregation functions - "all", - "any", - "corrwith", - "count", - "first", - "idxmax", - "idxmin", - "last", - "max", - "mean", - "median", - "min", - "nunique", - "prod", - "quantile", - "sem", - "size", - "skew", - "std", - "sum", - "var", - # Transformation functions - "bfill", - "cumcount", - "cummax", - "cummin", - "cumprod", - "cumsum", - "diff", - "ffill", - "fillna", - "ngroup", - "pct_change", - "rank", - "shift", -] StataDateFormat: TypeAlias = Literal[ "tc", diff --git a/pandas-stubs/core/groupby/base.pyi b/pandas-stubs/core/groupby/base.pyi index f56b6a32..d75ec559 100644 --- a/pandas-stubs/core/groupby/base.pyi +++ b/pandas-stubs/core/groupby/base.pyi @@ -1,7 +1,56 @@ from collections.abc import Hashable import dataclasses +from typing import ( + Literal, + TypeAlias, +) @dataclasses.dataclass(order=True, frozen=True) class OutputKey: label: Hashable position: int + +reduction_kernels: TypeAlias = Literal[ + "all", + "any", + "corrwith", + "count", + "first", + "idxmax", + "idxmin", + "last", + "max", + "mean", + "median", + "min", + "nunique", + "prod", + # as long as `quantile`'s signature accepts only + # a single quantile value, it's a reduction. + # GH#27526 might change that. + "quantile", + "sem", + "size", + "skew", + "std", + "sum", + "var", +] + +transformation_kernels: TypeAlias = Literal[ + "bfill", + "cumcount", + "cummax", + "cummin", + "cumprod", + "cumsum", + "diff", + "ffill", + "fillna", + "ngroup", + "pct_change", + "rank", + "shift", +] + +transform_kernel_allowlist: TypeAlias = reduction_kernels | transformation_kernels diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 793cc01a..c37206f2 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -19,6 +19,7 @@ from typing import ( from matplotlib.axes import Axes as PlotAxes import numpy as np from pandas.core.frame import DataFrame +from pandas.core.groupby.base import transform_kernel_allowlist from pandas.core.groupby.groupby import ( GroupBy, GroupByPlot, @@ -41,7 +42,6 @@ from pandas._typing import ( ByT, CorrelationMethod, Dtype, - GroupByFuncStrs, IndexLabel, Level, ListLike, @@ -109,7 +109,9 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): **kwargs: Any, ) -> UnknownSeries: ... @overload - def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> UnknownSeries: ... + def transform( + self, func: transform_kernel_allowlist, *args, **kwargs + ) -> UnknownSeries: ... def filter( self, func: Callable | str, dropna: bool = ..., *args, **kwargs ) -> Series: ... @@ -253,7 +255,9 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): **kwargs: Any, ) -> DataFrame: ... @overload - def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> DataFrame: ... + def transform( + self, func: transform_kernel_allowlist, *args, **kwargs + ) -> DataFrame: ... def filter( self, func: Callable, dropna: bool = ..., *args, **kwargs ) -> DataFrame: ... From 053b7e75346efbb0c1967200a8a594302fec5aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Sat, 7 Jun 2025 11:10:29 -0400 Subject: [PATCH 4/8] GH456 PR Feedback --- pandas-stubs/core/groupby/base.pyi | 6 +++--- pandas-stubs/core/groupby/generic.pyi | 6 +++--- tests/test_series.py | 10 +++++++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/pandas-stubs/core/groupby/base.pyi b/pandas-stubs/core/groupby/base.pyi index d75ec559..97faad2b 100644 --- a/pandas-stubs/core/groupby/base.pyi +++ b/pandas-stubs/core/groupby/base.pyi @@ -10,7 +10,7 @@ class OutputKey: label: Hashable position: int -reduction_kernels: TypeAlias = Literal[ +ReductionKernelType: TypeAlias = Literal[ "all", "any", "corrwith", @@ -37,7 +37,7 @@ reduction_kernels: TypeAlias = Literal[ "var", ] -transformation_kernels: TypeAlias = Literal[ +TransformationKernelType: TypeAlias = Literal[ "bfill", "cumcount", "cummax", @@ -53,4 +53,4 @@ transformation_kernels: TypeAlias = Literal[ "shift", ] -transform_kernel_allowlist: TypeAlias = reduction_kernels | transformation_kernels +TransformReductionListType: TypeAlias = ReductionKernelType | TransformationKernelType diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index c37206f2..b582ddd0 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -19,7 +19,7 @@ from typing import ( from matplotlib.axes import Axes as PlotAxes import numpy as np from pandas.core.frame import DataFrame -from pandas.core.groupby.base import transform_kernel_allowlist +from pandas.core.groupby.base import TransformReductionListType from pandas.core.groupby.groupby import ( GroupBy, GroupByPlot, @@ -110,7 +110,7 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): ) -> UnknownSeries: ... @overload def transform( - self, func: transform_kernel_allowlist, *args, **kwargs + self, func: TransformReductionListType, *args, **kwargs ) -> UnknownSeries: ... def filter( self, func: Callable | str, dropna: bool = ..., *args, **kwargs @@ -256,7 +256,7 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): ) -> DataFrame: ... @overload def transform( - self, func: transform_kernel_allowlist, *args, **kwargs + self, func: TransformReductionListType, *args, **kwargs ) -> DataFrame: ... def filter( self, func: Callable, dropna: bool = ..., *args, **kwargs diff --git a/tests/test_series.py b/tests/test_series.py index 66219342..ff84cec1 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1081,7 +1081,7 @@ def test_types_groupby_agg() -> None: def sum_sr(s: pd.Series[int]) -> int: # type of `sum` not well inferred by mypy - return sum(s) + return s.sum() check( assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"), @@ -1133,7 +1133,11 @@ def func(s: pd.Series[int]) -> float: return s.astype(float).min() s = pd.Series([1, 2, 3, 4]) - s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min()) + check( + assert_type(s.groupby([1, 1, 2, 2]).agg(func), "pd.Series[float]"), + pd.Series, + np.floating, + ) check( assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"), pd.Series, @@ -1155,7 +1159,7 @@ def func(s: pd.Series[int]) -> float: def sum_sr(s: pd.Series[int]) -> int: # type of `sum` not well inferred by mypy - return sum(s) + return s.sum() check( assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"), From 4141a06e53880dba16f0ac5f30f77167f9e295d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Sat, 7 Jun 2025 16:16:34 -0400 Subject: [PATCH 5/8] GH456 PR Feedback --- pandas-stubs/core/groupby/generic.pyi | 9 +++++++++ tests/test_series.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index b582ddd0..920c962f 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -71,6 +71,15 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): **kwargs, ) -> Series[S2]: ... @overload + def aggregate( + self, + func: Callable[[Series], S2], + *args, + engine: WindowingEngine = ..., + engine_kwargs: WindowingEngineKwargs = ..., + **kwargs, + ) -> Series[S2]: ... + @overload def aggregate( self, func: list[AggFuncTypeBase], diff --git a/tests/test_series.py b/tests/test_series.py index ff84cec1..40c80027 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1119,6 +1119,20 @@ def transform_func( pd.Series, float, ) + check( + assert_type( + s.groupby(lambda x: x).transform("mean"), + "pd.Series", + ), + pd.Series, + ) + check( + assert_type( + s.groupby(lambda x: x).transform("first"), + "pd.Series", + ), + pd.Series, + ) def test_types_groupby_aggregate() -> None: From f9863d03c387dcca86da8f86675a105a5d837786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Sat, 7 Jun 2025 22:11:00 -0400 Subject: [PATCH 6/8] GH456 PR Feedback --- tests/test_series.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_series.py b/tests/test_series.py index 40c80027..f5643ff4 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1164,6 +1164,7 @@ def func(s: pd.Series[int]) -> float: pd.Series, np.floating, ) + check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), "pd.Series[float]"), pd.Series, int) with pytest_warns_bounded( FutureWarning, From e26b4c1118545a44bff495a2a305314861366cdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Mon, 9 Jun 2025 20:35:05 -0400 Subject: [PATCH 7/8] GH456 PR Feedback --- tests/test_series.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_series.py b/tests/test_series.py index f5643ff4..12cc0b6c 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1164,7 +1164,9 @@ def func(s: pd.Series[int]) -> float: pd.Series, np.floating, ) - check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), "pd.Series[float]"), pd.Series, int) + + # test below passes with mypy but pyright correctly sees it as pd.Series[float] + # check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), pd.Series), pd.Series, float) with pytest_warns_bounded( FutureWarning, From 96abf3b717800cf38fd607771ef12532dfdf01e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Diridollou?= Date: Wed, 11 Jun 2025 13:08:30 -0400 Subject: [PATCH 8/8] GH456 PR Feedback --- tests/test_series.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_series.py b/tests/test_series.py index 12cc0b6c..161aedf4 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1165,8 +1165,8 @@ def func(s: pd.Series[int]) -> float: np.floating, ) - # test below passes with mypy but pyright correctly sees it as pd.Series[float] - # check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), pd.Series), pd.Series, float) + # test below fails with mypy but pyright correctly sees it as pd.Series[float] + # check(assert_type(s.groupby([1,1,2,2]).agg(lambda x: x.astype(float).min()), "pd.Series[float]"), pd.Series, float) with pytest_warns_bounded( FutureWarning,