Skip to content

Commit 57b7213

Browse files
committed
Allow covariance in the agg dict passed to DataFrame or Series groupby.agg()
1 parent b12acc1 commit 57b7213

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

pandas-stubs/_typing.pyi

+2 -2
Original file line number | Diff line number | Diff line change
@@ -126,8 +126,8 @@ F = TypeVar("F", bound=FuncType)
126126
HashableT = TypeVar("HashableT", bound=Hashable)
127127

128128
AggFuncTypeBase: TypeAlias = Union[Callable, str, np.ufunc]
129-
AggFuncTypeDictSeries: TypeAlias = dict[Hashable, AggFuncTypeBase]
130-
AggFuncTypeDictFrame: TypeAlias = dict[
129+
AggFuncTypeDictSeries: TypeAlias = Mapping[Hashable, AggFuncTypeBase]
130+
AggFuncTypeDictFrame: TypeAlias = Mapping[
131131
Hashable, Union[AggFuncTypeBase, list[AggFuncTypeBase]]
132132
]
133133
AggFuncTypeSeriesToFrame: TypeAlias = Union[

tests/test_frame.py

+22 -16
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,10 @@
3333
from typing_extensions import assert_type
3434
import xarray as xr
3535

36-
from pandas._typing import Scalar
36+
from pandas._typing import (
37+
AggFuncTypeBase,
38+
Scalar,
39+
)
3740

3841
from tests import (
3942
TYPE_CHECKING_INVALID_USAGE,
@@ -643,7 +646,9 @@ def test_types_groupby_methods() -> None:
643646

644647

645648
def test_types_groupby_agg() -> None:
646-
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0]})
649+
df = pd.DataFrame(
650+
data={"col1": [1, 1, 2], "col2": [3, 4, 5], "col3": [0, 1, 0], 0: [-1, -1, -1]}
651+
)
647652
check(assert_type(df.groupby("col1")["col3"].agg(min), pd.Series), pd.Series)
648653
check(
649654
assert_type(df.groupby("col1")["col3"].agg([min, max]), pd.DataFrame),
@@ -655,21 +660,19 @@ def test_types_groupby_agg() -> None:
655660
assert_type(df.groupby("col1").agg(["min", "max"]), pd.DataFrame), pd.DataFrame
656661
)
657662
check(assert_type(df.groupby("col1").agg([min, max]), pd.DataFrame), pd.DataFrame)
663+
agg_dict1: dict[Hashable, str] = {"col2": "min", "col3": "max", 0: "avg"}
664+
check(assert_type(df.groupby("col1").agg(agg_dict1), pd.DataFrame), pd.DataFrame)
665+
agg_dict2: dict[Hashable, AggFuncTypeBase] = {"col2": min, "col3": max, 0: min}
666+
check(assert_type(df.groupby("col1").agg(agg_dict2), pd.DataFrame), pd.DataFrame)
667+
agg_dict3: dict[Hashable, str | AggFuncTypeBase] = {
668+
"col2": min,
669+
"col3": "max",
670+
0: lambda x: x.min(),
671+
}
672+
check(assert_type(df.groupby("col1").agg(agg_dict3), pd.DataFrame), pd.DataFrame)
673+
named_agg = pd.NamedAgg(column="col2", aggfunc="max")
658674
check(
659-
assert_type(
660-
df.groupby("col1").agg({"col2": "min", "col3": "max"}), pd.DataFrame
661-
),
662-
pd.DataFrame,
663-
)
664-
check(
665-
assert_type(df.groupby("col1").agg({"col2": min, "col3": max}), pd.DataFrame),
666-
pd.DataFrame,
667-
)
668-
check(
669-
assert_type(
670-
df.groupby("col1").agg(new_col=pd.NamedAgg(column="col2", aggfunc="max")),
671-
pd.DataFrame,
672-
),
675+
assert_type(df.groupby("col1").agg(new_col=named_agg), pd.DataFrame),
673676
pd.DataFrame,
674677
)
675678
# GH#187
@@ -679,6 +682,9 @@ def test_types_groupby_agg() -> None:
679682
cols_opt: list[str | None] = ["col1", "col2"]
680683
check(assert_type(df.groupby(by=cols_opt).sum(), pd.DataFrame), pd.DataFrame)
681684

685+
cols_mixed: list[str | int] = ["col1", 0]
686+
check(assert_type(df.groupby(by=cols_mixed).sum(), pd.DataFrame), pd.DataFrame)
687+
682688

683689
# This was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html
684690
def test_types_group_by_with_dropna_keyword() -> None:

0 commit comments

Comments
 (0)