Skip to content

Commit 79295aa

Browse files
committed
Make Series.groupby.transform annotation more precise
1 parent 09553b0 commit 79295aa

File tree

3 files changed

+105
-6
lines changed

3 files changed

+105
-6
lines changed

pandas-stubs/_typing.pyi

+64-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ from pandas.core.generic import NDFrame
2424
from pandas.core.groupby.grouper import Grouper
2525
from pandas.core.indexes.base import Index
2626
from pandas.core.series import Series
27-
from typing_extensions import TypeAlias
27+
from typing_extensions import (
28+
ParamSpec,
29+
TypeAlias,
30+
)
2831

2932
from pandas._libs.interval import Interval
3033
from pandas._libs.tslibs import (
@@ -123,6 +126,7 @@ JSONSerializable: TypeAlias = Union[PythonScalar, list, dict]
123126
Axes: TypeAlias = Union[AnyArrayLike, list, dict, range, tuple]
124127
Renamer: TypeAlias = Union[Mapping[Any, Label], Callable[[Any], Label]]
125128
T = TypeVar("T")
129+
P = ParamSpec("P")
126130
FuncType: TypeAlias = Callable[..., Any]
127131
F = TypeVar("F", bound=FuncType)
128132
HashableT = TypeVar("HashableT", bound=Hashable)
@@ -202,6 +206,27 @@ S1 = TypeVar(
202206
Interval[Timestamp],
203207
Interval[Timedelta],
204208
)
209+
S2 = TypeVar(
210+
"S2",
211+
str,
212+
bytes,
213+
datetime.date,
214+
datetime.datetime,
215+
datetime.time,
216+
datetime.timedelta,
217+
bool,
218+
int,
219+
float,
220+
complex,
221+
Timestamp,
222+
Timedelta,
223+
np.datetime64,
224+
Period,
225+
Interval[int],
226+
Interval[float],
227+
Interval[Timestamp],
228+
Interval[Timedelta],
229+
)
205230
T1 = TypeVar(
206231
"T1", str, int, np.int64, np.uint64, np.float64, float, np.dtype[np.generic]
207232
)
@@ -285,6 +310,44 @@ GroupByObjectNonScalar: TypeAlias = Union[
285310
list[Grouper],
286311
]
287312
GroupByObject: TypeAlias = Union[Scalar, GroupByObjectNonScalar]
313+
GroupByFuncStrs: TypeAlias = Literal[
314+
# Reduction/aggregation functions
315+
"all",
316+
"any",
317+
"corrwith",
318+
"count",
319+
"first",
320+
"idxmax",
321+
"idxmin",
322+
"last",
323+
"max",
324+
"mean",
325+
"median",
326+
"min",
327+
"nunique",
328+
"prod",
329+
"quantile",
330+
"sem",
331+
"size",
332+
"skew",
333+
"std",
334+
"sum",
335+
"var",
336+
# Transformation functions
337+
"bfill",
338+
"cumcount",
339+
"cummax",
340+
"cummin",
341+
"cumprod",
342+
"cumsum",
343+
"diff",
344+
"ffill",
345+
"fillna",
346+
"ngroup",
347+
"pct_change",
348+
"rank",
349+
"shift",
350+
]
288351

289352
StataDateFormat: TypeAlias = Literal[
290353
"tc",

pandas-stubs/core/groupby/generic.pyi

+20-2
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,21 @@ from pandas.core.groupby.groupby import ( # , get_groupby as get_groupby
2222
)
2323
from pandas.core.groupby.grouper import Grouper
2424
from pandas.core.series import Series
25-
from typing_extensions import TypeAlias
25+
from typing_extensions import (
26+
Concatenate,
27+
TypeAlias,
28+
)
2629

2730
from pandas._typing import (
2831
S1,
32+
S2,
2933
AggFuncTypeBase,
3034
AggFuncTypeFrame,
3135
AxisType,
36+
GroupByFuncStrs,
3237
Level,
3338
ListLike,
39+
P,
3440
Scalar,
3541
)
3642

@@ -61,7 +67,19 @@ class SeriesGroupBy(GroupBy, Generic[S1]):
6167
def agg(self, func: list[AggFuncTypeBase], *args, **kwargs) -> DataFrame: ...
6268
@overload
6369
def agg(self, func: AggFuncTypeBase, *args, **kwargs) -> Series: ...
64-
def transform(self, func: Callable | str, *args, **kwargs) -> Series: ...
70+
@overload
71+
def transform(
72+
self,
73+
func: Callable[Concatenate[Series[S1], P], Series[S2]],
74+
*args: Any, # Ideally we would use P.args here, but it does not work with engine/engine_kwargs present
75+
engine: Literal["cython", "numba", None] = ...,
76+
engine_kwargs: dict[str, Any] | None = ...,
77+
**kwargs: Any, # Ideally we would use P.kwargs here, but it does not work with engine/engine_kwargs present
78+
) -> Series[S2]: ...
79+
@overload
80+
# TODO: We may want to consider limiting this to transformation functions only: although using an aggregation
81+
# function like "sum" works, it makes no sense to use it in conjunction with transform()
82+
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> Series: ...
6583
def filter(self, func, dropna: bool = ..., *args, **kwargs): ...
6684
def nunique(self, dropna: bool = ...) -> Series: ...
6785
def describe(self, **kwargs) -> DataFrame: ...

tests/test_series.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,24 @@ def test_types_groupby_methods() -> None:
517517
check(assert_type(s.groupby(level=0).unique(), pd.Series), pd.Series)
518518

519519

520+
def test_types_groupby_transform() -> None:
521+
s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
522+
523+
def transform_func(
524+
x: pd.Series[int], pos_arg: bool, kw_arg: str
525+
) -> pd.Series[float]:
526+
return x / (2.0 if pos_arg else 1.0)
527+
528+
check(
529+
assert_type(
530+
s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"),
531+
"pd.Series[float]",
532+
),
533+
pd.Series,
534+
float,
535+
)
536+
537+
520538
def test_types_groupby_agg() -> None:
521539
s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
522540
check(assert_type(s.groupby(level=0).agg("sum"), pd.Series), pd.Series)
@@ -641,9 +659,9 @@ def test_types_aggregate() -> None:
641659

642660

643661
def test_types_transform() -> None:
644-
s = pd.Series([1, 2, 3], index=["col1", "col2", "col3"])
645-
check(assert_type(s.transform("abs"), pd.Series), pd.Series)
646-
check(assert_type(s.transform(abs), pd.Series), pd.Series)
662+
s: pd.Series[int] = pd.Series([1, 2, 3], index=["col1", "col2", "col3"])
663+
check(assert_type(s.transform("abs"), "pd.Series[int]"), pd.Series, int)
664+
check(assert_type(s.transform(abs), "pd.Series[int]"), pd.Series, int)
647665
check(assert_type(s.transform(["abs", "sqrt"]), pd.DataFrame), pd.DataFrame)
648666
check(assert_type(s.transform([abs, np.sqrt]), pd.DataFrame), pd.DataFrame)
649667
check(

0 commit comments

Comments
 (0)