Skip to content

Commit f051cd7

Browse files
authored
Allow covariance in the agg dict passed to DataFrame or Series groupby.agg() (#363)
* Allow covariance in the agg dict passed to DataFrame or Series groupby.agg() * Add test cases for agg dicts with keys which are sub-types of Hashable and adjust type annotation to support this * Remove/adjust agg dict annotations on test * Change lambda to a local function
1 parent 5951319 commit f051cd7

File tree

2 files changed

+30
-18
lines changed

2 files changed

+30
-18
lines changed

pandas-stubs/_typing.pyi

+3-3
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ HashableT4 = TypeVar("HashableT4", bound=Hashable)
132132
HashableT5 = TypeVar("HashableT5", bound=Hashable)
133133

134134
AggFuncTypeBase: TypeAlias = Union[Callable, str, np.ufunc]
135-
AggFuncTypeDictSeries: TypeAlias = dict[Hashable, AggFuncTypeBase]
136-
AggFuncTypeDictFrame: TypeAlias = dict[
137-
Hashable, Union[AggFuncTypeBase, list[AggFuncTypeBase]]
135+
AggFuncTypeDictSeries: TypeAlias = Mapping[HashableT, AggFuncTypeBase]
136+
AggFuncTypeDictFrame: TypeAlias = Mapping[
137+
HashableT, Union[AggFuncTypeBase, list[AggFuncTypeBase]]
138138
]
139139
AggFuncTypeSeriesToFrame: TypeAlias = Union[
140140
list[AggFuncTypeBase],

tests/test_frame.py

+27-15
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,9 @@ def test_types_groupby_methods() -> None:
643643

644644

645645
def test_types_groupby_agg() -> None:
646-
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
646+
df = pd.DataFrame(
647+
data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0], 0: [-1, -1, -1]}
648+
)
647649
check(assert_type(df.groupby("col1")["col3"].agg(min), pd.Series), pd.Series)
648650
check(
649651
assert_type(df.groupby("col1")["col3"].agg([min, max]), pd.DataFrame),
@@ -655,21 +657,28 @@ def test_types_groupby_agg() -> None:
655657
assert_type(df.groupby("col1").agg(["min", "max"]), pd.DataFrame), pd.DataFrame
656658
)
657659
check(assert_type(df.groupby("col1").agg([min, max]), pd.DataFrame), pd.DataFrame)
660+
agg_dict1 = {"col2": "min", "col3": "max", 0: "sum"}
661+
check(assert_type(df.groupby("col1").agg(agg_dict1), pd.DataFrame), pd.DataFrame)
662+
agg_dict2 = {"col2": min, "col3": max, 0: min}
663+
check(assert_type(df.groupby("col1").agg(agg_dict2), pd.DataFrame), pd.DataFrame)
664+
665+
def wrapped_min(x: Any) -> Any:
666+
return x.min()
667+
668+
# Here, MyPy infers dict[object, object], so it must be explicitly annotated
669+
agg_dict3: dict[str | int, str | Callable] = {
670+
"col2": min,
671+
"col3": "max",
672+
0: wrapped_min,
673+
}
674+
check(assert_type(df.groupby("col1").agg(agg_dict3), pd.DataFrame), pd.DataFrame)
675+
agg_dict4 = {"col2": "sum"}
676+
check(assert_type(df.groupby("col1").agg(agg_dict4), pd.DataFrame), pd.DataFrame)
677+
agg_dict5 = {0: "sum"}
678+
check(assert_type(df.groupby("col1").agg(agg_dict5), pd.DataFrame), pd.DataFrame)
679+
named_agg = pd.NamedAgg(column="col2", aggfunc="max")
658680
check(
659-
assert_type(
660-
df.groupby("col1").agg({"col2": "min", "col3": "max"}), pd.DataFrame
661-
),
662-
pd.DataFrame,
663-
)
664-
check(
665-
assert_type(df.groupby("col1").agg({"col2": min, "col3": max}), pd.DataFrame),
666-
pd.DataFrame,
667-
)
668-
check(
669-
assert_type(
670-
df.groupby("col1").agg(new_col=pd.NamedAgg(column="col2", aggfunc="max")),
671-
pd.DataFrame,
672-
),
681+
assert_type(df.groupby("col1").agg(new_col=named_agg), pd.DataFrame),
673682
pd.DataFrame,
674683
)
675684
# GH#187
@@ -679,6 +688,9 @@ def test_types_groupby_agg() -> None:
679688
cols_opt: list[str | None] = ["col1", "col2"]
680689
check(assert_type(df.groupby(by=cols_opt).sum(), pd.DataFrame), pd.DataFrame)
681690

691+
cols_mixed: list[str | int] = ["col1", 0]
692+
check(assert_type(df.groupby(by=cols_mixed).sum(), pd.DataFrame), pd.DataFrame)
693+
682694

683695
# This was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html
684696
def test_types_group_by_with_dropna_keyword() -> None:

0 commit comments

Comments
 (0)