-
-
Notifications
You must be signed in to change notification settings - Fork 143
GH456 First attempt GroupBy.transform improved typing #1242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
020f93d
106a6f5
3bba101
053b7e7
4141a06
f9863d0
e26b4c1
96abf3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -7,6 +7,7 @@ from collections.abc import ( | |||
) | ||||
from typing import ( | ||||
Any, | ||||
Concatenate, | ||||
Generic, | ||||
Literal, | ||||
NamedTuple, | ||||
|
@@ -22,7 +23,10 @@ from pandas.core.groupby.groupby import ( | |||
GroupBy, | ||||
GroupByPlot, | ||||
) | ||||
from pandas.core.series import Series | ||||
from pandas.core.series import ( | ||||
Series, | ||||
UnknownSeries, | ||||
) | ||||
from typing_extensions import ( | ||||
Self, | ||||
TypeAlias, | ||||
|
@@ -31,15 +35,18 @@ from typing_extensions import ( | |||
from pandas._libs.tslibs.timestamps import Timestamp | ||||
from pandas._typing import ( | ||||
S1, | ||||
S2, | ||||
AggFuncTypeBase, | ||||
AggFuncTypeFrame, | ||||
ByT, | ||||
CorrelationMethod, | ||||
Dtype, | ||||
GroupByFuncStrs, | ||||
IndexLabel, | ||||
Level, | ||||
ListLike, | ||||
NsmallestNlargestKeep, | ||||
P, | ||||
Scalar, | ||||
TakeIndexer, | ||||
WindowingEngine, | ||||
|
@@ -53,10 +60,21 @@ class NamedAgg(NamedTuple): | |||
aggfunc: AggScalar | ||||
|
||||
class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): | ||||
@overload | ||||
def aggregate( | ||||
self, | ||||
func: Callable[Concatenate[Series[S1], P], S2], | ||||
/, | ||||
*args, | ||||
engine: WindowingEngine = ..., | ||||
engine_kwargs: WindowingEngineKwargs = ..., | ||||
**kwargs, | ||||
) -> Series[S2]: ... | ||||
@overload | ||||
def aggregate( | ||||
self, | ||||
func: list[AggFuncTypeBase], | ||||
/, | ||||
*args, | ||||
engine: WindowingEngine = ..., | ||||
engine_kwargs: WindowingEngineKwargs = ..., | ||||
|
@@ -66,20 +84,32 @@ class SeriesGroupBy(GroupBy[Series[S1]], Generic[S1, ByT]): | |||
def aggregate( | ||||
self, | ||||
func: AggFuncTypeBase | None = ..., | ||||
/, | ||||
*args, | ||||
Comment on lines
93
to
97
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before this overload, you could add this overload: @overload
def aggregate(
self,
func: Callable[[Series], S2],
*args,
engine: WindowingEngine = ...,
engine_kwargs: WindowingEngineKwargs = ...,
**kwargs,
) -> Series[S2]: ... Then you know that if you start with a s = pd.Series([1, 2, 3, 4])
q = s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min()) In this case, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that's because the type of But I think it would work if you did Because then it can know that Can you try that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried that for the last push, see pandas-stubs/tests/test_series.py Line 1167 in f9863d0
It fails in all CI:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When I look with how
so that may explain why it fails on lambda expressions whatsoever. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK - so we can leave the |
||||
engine: WindowingEngine = ..., | ||||
engine_kwargs: WindowingEngineKwargs = ..., | ||||
**kwargs, | ||||
) -> Series: ... | ||||
) -> UnknownSeries: ... | ||||
agg = aggregate | ||||
@overload | ||||
def transform( | ||||
self, | ||||
func: Callable | str, | ||||
*args, | ||||
func: Callable[Concatenate[Series[S1], P], Series[S2]], | ||||
/, | ||||
*args: Any, | ||||
Dr-Irv marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
engine: WindowingEngine = ..., | ||||
engine_kwargs: WindowingEngineKwargs = ..., | ||||
**kwargs, | ||||
) -> Series: ... | ||||
**kwargs: Any, | ||||
) -> Series[S2]: ... | ||||
@overload | ||||
def transform( | ||||
self, | ||||
func: Callable, | ||||
*args: Any, | ||||
**kwargs: Any, | ||||
) -> UnknownSeries: ... | ||||
@overload | ||||
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> UnknownSeries: ... | ||||
def filter( | ||||
self, func: Callable | str, dropna: bool = ..., *args, **kwargs | ||||
) -> Series: ... | ||||
|
@@ -206,14 +236,24 @@ class DataFrameGroupBy(GroupBy[DataFrame], Generic[ByT, _TT]): | |||
**kwargs, | ||||
) -> DataFrame: ... | ||||
agg = aggregate | ||||
@overload | ||||
def transform( | ||||
self, | ||||
func: Callable | str, | ||||
*args, | ||||
func: Callable[Concatenate[DataFrame, P], DataFrame], | ||||
*args: Any, | ||||
engine: WindowingEngine = ..., | ||||
Dr-Irv marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
engine_kwargs: WindowingEngineKwargs = ..., | ||||
**kwargs, | ||||
**kwargs: Any, | ||||
) -> DataFrame: ... | ||||
@overload | ||||
def transform( | ||||
self, | ||||
func: Callable, | ||||
*args: Any, | ||||
**kwargs: Any, | ||||
) -> DataFrame: ... | ||||
@overload | ||||
def transform(self, func: GroupByFuncStrs, *args, **kwargs) -> DataFrame: ... | ||||
def filter( | ||||
self, func: Callable, dropna: bool = ..., *args, **kwargs | ||||
) -> DataFrame: ... | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1078,25 +1078,90 @@ def test_types_groupby_agg() -> None: | |
r"The provided callable <built-in function (min|sum)> is currently using", | ||
upper="2.2.99", | ||
): | ||
check(assert_type(s.groupby(level=0).agg(sum), pd.Series), pd.Series) | ||
|
||
def sum_sr(s: pd.Series[int]) -> int: | ||
# type of `sum` not well inferred by mypy | ||
return sum(s) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue was if I passed |
||
|
||
check( | ||
assert_type(s.groupby(level=0).agg(sum_sr), "pd.Series[int]"), | ||
pd.Series, | ||
np.integer, | ||
) | ||
check( | ||
assert_type(s.groupby(level=0).agg([min, sum]), pd.DataFrame), pd.DataFrame | ||
) | ||
|
||
|
||
def test_types_groupby_transform() -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should add tests for two of the string transform arguments (e.g., "mean", "first") |
||
s: pd.Series[int] = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) | ||
|
||
def transform_func( | ||
x: pd.Series[int], pos_arg: bool, kw_arg: str | ||
) -> pd.Series[float]: | ||
return x / (2.0 if pos_arg else 1.0) | ||
|
||
check( | ||
assert_type( | ||
s.groupby(lambda x: x).transform(transform_func, True, kw_arg="foo"), | ||
"pd.Series[float]", | ||
), | ||
Dr-Irv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pd.Series, | ||
float, | ||
) | ||
check( | ||
assert_type( | ||
s.groupby(lambda x: x).transform( | ||
transform_func, True, engine="cython", kw_arg="foo" | ||
), | ||
"pd.Series[float]", | ||
), | ||
pd.Series, | ||
float, | ||
) | ||
|
||
|
||
def test_types_groupby_aggregate() -> None: | ||
s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) | ||
check(assert_type(s.groupby(level=0).aggregate("sum"), pd.Series), pd.Series) | ||
check( | ||
assert_type(s.groupby(level=0).aggregate(["min", "sum"]), pd.DataFrame), | ||
pd.DataFrame, | ||
) | ||
|
||
def func(s: pd.Series[int]) -> float: | ||
return s.astype(float).min() | ||
|
||
s = pd.Series([1, 2, 3, 4]) | ||
s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't you want a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct my mistake There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Turns out the inference on the fly of lambdas is not super clear so you need to define the function on the side to have the right types. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that is an issue with lambda functions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I think you can have a test of check(assert_type( s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min()), pd.Series), pd.Series) which would be worthwhile |
||
check( | ||
assert_type(s.groupby(level=0).aggregate(func), "pd.Series[float]"), | ||
pd.Series, | ||
np.floating, | ||
) | ||
check( | ||
assert_type( | ||
s.groupby(level=0).aggregate(func, engine="cython"), "pd.Series[float]" | ||
), | ||
pd.Series, | ||
np.floating, | ||
) | ||
|
||
with pytest_warns_bounded( | ||
FutureWarning, | ||
r"The provided callable <built-in function (min|sum)> is currently using", | ||
upper="2.2.99", | ||
): | ||
check(assert_type(s.groupby(level=0).aggregate(sum), pd.Series), pd.Series) | ||
|
||
def sum_sr(s: pd.Series[int]) -> int: | ||
# type of `sum` not well inferred by mypy | ||
return sum(s) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use |
||
|
||
check( | ||
assert_type(s.groupby(level=0).aggregate(sum_sr), "pd.Series[int]"), | ||
pd.Series, | ||
np.integer, | ||
) | ||
check( | ||
assert_type(s.groupby(level=0).aggregate([min, sum]), pd.DataFrame), | ||
pd.DataFrame, | ||
|
Uh oh!
There was an error while loading. Please reload this page.