diff --git a/pandas/core/aggregation.py b/pandas/core/aggregation.py index f6380808d5ac2..d947c4cf2abfa 100644 --- a/pandas/core/aggregation.py +++ b/pandas/core/aggregation.py @@ -5,13 +5,18 @@ from collections import defaultdict from functools import partial -from typing import Any, DefaultDict, List, Sequence, Tuple +from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Sequence, Tuple + +from pandas._typing import Scalar from pandas.core.dtypes.common import is_dict_like, is_list_like import pandas.core.common as com from pandas.core.indexes.api import Index +if TYPE_CHECKING: + import numpy as np # noqa: F401 + def is_multi_agg_with_relabel(**kwargs) -> bool: """ @@ -39,7 +44,9 @@ def is_multi_agg_with_relabel(**kwargs) -> bool: ) -def normalize_keyword_aggregation(kwargs: dict) -> Tuple[dict, List[str], List[int]]: +def normalize_keyword_aggregation( + kwargs: Dict[str, str] +) -> Tuple[DefaultDict[str, List[Scalar]], Tuple[str, ...], "np.ndarray"]: """ Normalize user-provided "named aggregation" kwargs. Transforms from the new ``Mapping[str, NamedAgg]`` style kwargs @@ -51,11 +58,11 @@ def normalize_keyword_aggregation(kwargs: dict) -> Tuple[dict, List[str], List[i Returns ------- - aggspec : dict + aggspec : collections.defaultdict of lists The transformed kwargs. - columns : List[str] + columns : tuple The user-provided keys. - col_idx_order : List[int] + col_idx_order : numpy.ndarray List of columns indices. Examples diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 208cbfc5b06d6..d2d92e98e7f0a 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -14,11 +14,13 @@ TYPE_CHECKING, Any, Callable, + DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, + Optional, Sequence, Tuple, Type, @@ -30,7 +32,7 @@ import numpy as np from pandas._libs import Timestamp, lib -from pandas._typing import FrameOrSeries +from pandas._typing import FrameOrSeries, Scalar from pandas.util._decorators import Appender, Substitution, doc from pandas.core.dtypes.cast import ( @@ -909,7 +911,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]): axis="", ) @Appender(_shared_docs["aggregate"]) - def aggregate(self, func=None, *args, **kwargs): + def aggregate( + self, func: Optional[DefaultDict[str, List[Scalar]]] = None, *args, **kwargs + ): relabeling = func is None and is_multi_agg_with_relabel(**kwargs) if relabeling: