diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 6bdb6f91e..759146c55 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -404,6 +404,29 @@ Function: TypeAlias = np.ufunc | Callable[..., Any] # shared HashableT and HashableT#. This one can be used if the identical # type is need in a function that uses GroupByObjectNonScalar _HashableTa = TypeVar("_HashableTa", bound=Hashable) +ByT = TypeVar( + "ByT", + str, + bytes, + datetime.date, + datetime.datetime, + datetime.timedelta, + np.datetime64, + np.timedelta64, + bool, + int, + float, + complex, + Timestamp, + Timedelta, + Scalar, + Period, + Interval[int], + Interval[float], + Interval[Timestamp], + Interval[Timedelta], + tuple, +) GroupByObjectNonScalar: TypeAlias = ( tuple | list[_HashableTa] diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index b94912442..22d2e1732 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -20,18 +20,22 @@ from typing import ( from matplotlib.axes import Axes as PlotAxes import numpy as np from pandas import ( + Period, Timedelta, Timestamp, ) from pandas.core.arraylike import OpsMixin from pandas.core.generic import NDFrame -from pandas.core.groupby.generic import ( - _DataFrameGroupByNonScalar, - _DataFrameGroupByScalar, -) +from pandas.core.groupby.generic import DataFrameGroupBy from pandas.core.groupby.grouper import Grouper from pandas.core.indexers import BaseIndexer from pandas.core.indexes.base import Index +from pandas.core.indexes.category import CategoricalIndex +from pandas.core.indexes.datetimes import DatetimeIndex +from pandas.core.indexes.interval import IntervalIndex +from pandas.core.indexes.multi import MultiIndex +from pandas.core.indexes.period import PeriodIndex +from pandas.core.indexes.timedeltas import TimedeltaIndex from pandas.core.indexing import ( _iLocIndexer, _IndexSliceTuple, @@ -82,6 +86,7 @@ from pandas._typing import ( IndexLabel, IndexType, IntervalClosedType, + IntervalT, JoinHow, JsonFrameOrient, Label, @@ -1011,7 +1016,85 @@ class DataFrame(NDFrame, OpsMixin): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> _DataFrameGroupByScalar: ... + ) -> DataFrameGroupBy[Scalar]: ... + @overload + def groupby( # type: ignore[misc] # pyright: ignore[reportOverlappingOverload] + self, + by: DatetimeIndex, + axis: Axis = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Timestamp]: ... + @overload + def groupby( # type: ignore[misc] + self, + by: TimedeltaIndex, + axis: Axis = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Timedelta]: ... + @overload + def groupby( # type: ignore[misc] + self, + by: PeriodIndex, + axis: Axis = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Period]: ... + @overload + def groupby( # type: ignore[misc] + self, + by: IntervalIndex[IntervalT], + axis: Axis = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[IntervalT]: ... + @overload + def groupby( + self, + by: MultiIndex, + axis: Axis = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[tuple]: ... + @overload + def groupby( + self, + by: CategoricalIndex | Index, + axis: Axis = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Any]: ... @overload def groupby( self, @@ -1024,7 +1107,7 @@ class DataFrame(NDFrame, OpsMixin): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> _DataFrameGroupByNonScalar: ... + ) -> DataFrameGroupBy[tuple]: ... def pivot( self, *, diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index c17923a51..971fe2aca 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -30,6 +30,7 @@ from pandas._typing import ( AggFuncTypeBase, AggFuncTypeFrame, Axis, + ByT, Level, ListLike, RandomState, @@ -146,13 +147,7 @@ class SeriesGroupBy(GroupBy, Generic[S1]): def idxmax(self, axis: Axis = ..., skipna: bool = ...) -> Series: ... def idxmin(self, axis: Axis = ..., skipna: bool = ...) -> Series: ... -class _DataFrameGroupByScalar(DataFrameGroupBy): - def __iter__(self) -> Iterator[tuple[Scalar, DataFrame]]: ... - -class _DataFrameGroupByNonScalar(DataFrameGroupBy): - def __iter__(self) -> Iterator[tuple[tuple, DataFrame]]: ... - -class DataFrameGroupBy(GroupBy): +class DataFrameGroupBy(GroupBy, Generic[ByT]): def any(self, skipna: bool = ...) -> DataFrame: ... def all(self, skipna: bool = ...) -> DataFrame: ... # error: Overload 3 for "apply" will never be used because its parameters overlap overload 1 @@ -178,7 +173,7 @@ class DataFrameGroupBy(GroupBy): @overload def __getitem__(self, item: str) -> SeriesGroupBy: ... @overload - def __getitem__(self, item: list[str]) -> DataFrameGroupBy: ... + def __getitem__(self, item: list[str]) -> DataFrameGroupBy[ByT]: ... def count(self) -> DataFrame: ... def boxplot( self, @@ -364,3 +359,4 @@ class DataFrameGroupBy(GroupBy): dropna: bool = ..., ) -> Series[float]: ... def __getattr__(self, name: str) -> SeriesGroupBy: ... + def __iter__(self) -> Iterator[tuple[ByT, DataFrame]]: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 94f279563..1bc70a591 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1972,6 +1972,17 @@ def test_groupby_result() -> None: check(assert_type(index2, Scalar), int) check(assert_type(value2, pd.DataFrame), pd.DataFrame) + # GH 674 + # grouping by pd.MultiIndex should always resolve to a tuple as well + multi_index = pd.MultiIndex.from_frame(df[["a", "b"]]) + iterator3 = df.groupby(multi_index).__iter__() + assert_type(iterator3, Iterator[Tuple[Tuple, pd.DataFrame]]) + index3, value3 = next(iterator3) + assert_type((index3, value3), Tuple[Tuple, pd.DataFrame]) + + check(assert_type(index3, Tuple), tuple, int) + check(assert_type(value3, pd.DataFrame), pd.DataFrame) + # Want to make sure these cases are differentiated for (k1, k2), g in df.groupby(["a", "b"]): pass @@ -1979,6 +1990,89 @@ def test_groupby_result() -> None: for kk, g in df.groupby("a"): pass + for (k1, k2), g in df.groupby(multi_index): + pass + + +def test_groupby_result_for_scalar_indexes() -> None: + # GH 674 + dates = pd.date_range("2020-01-01", "2020-12-31") + df = pd.DataFrame({"date": dates, "days": 1}) + period_index = pd.PeriodIndex(df.date, freq="M") + iterator = df.groupby(period_index).__iter__() + assert_type(iterator, Iterator[Tuple[pd.Period, pd.DataFrame]]) + index, value = next(iterator) + assert_type((index, value), Tuple[pd.Period, pd.DataFrame]) + + check(assert_type(index, pd.Period), pd.Period) + check(assert_type(value, pd.DataFrame), pd.DataFrame) + + dt_index = pd.DatetimeIndex(dates) + iterator2 = df.groupby(dt_index).__iter__() + assert_type(iterator2, Iterator[Tuple[pd.Timestamp, pd.DataFrame]]) + index2, value2 = next(iterator2) + assert_type((index2, value2), Tuple[pd.Timestamp, pd.DataFrame]) + + check(assert_type(index2, pd.Timestamp), pd.Timestamp) + check(assert_type(value2, pd.DataFrame), pd.DataFrame) + + tdelta_index = pd.TimedeltaIndex(dates - pd.Timestamp("2020-01-01")) + iterator3 = df.groupby(tdelta_index).__iter__() + assert_type(iterator3, Iterator[Tuple[pd.Timedelta, pd.DataFrame]]) + index3, value3 = next(iterator3) + assert_type((index3, value3), Tuple[pd.Timedelta, pd.DataFrame]) + + check(assert_type(index3, pd.Timedelta), pd.Timedelta) + check(assert_type(value3, pd.DataFrame), pd.DataFrame) + + intervals: list[pd.Interval[pd.Timestamp]] = [ + pd.Interval(date, date + pd.DateOffset(days=1), closed="left") for date in dates + ] + interval_index = pd.IntervalIndex(intervals) + assert_type(interval_index, "pd.IntervalIndex[pd.Interval[pd.Timestamp]]") + iterator4 = df.groupby(interval_index).__iter__() + assert_type(iterator4, Iterator[Tuple["pd.Interval[pd.Timestamp]", pd.DataFrame]]) + index4, value4 = next(iterator4) + assert_type((index4, value4), Tuple["pd.Interval[pd.Timestamp]", pd.DataFrame]) + + check(assert_type(index4, "pd.Interval[pd.Timestamp]"), pd.Interval) + check(assert_type(value4, pd.DataFrame), pd.DataFrame) + + for p, g in df.groupby(period_index): + pass + + for dt, g in df.groupby(dt_index): + pass + + for tdelta, g in df.groupby(tdelta_index): + pass + + for interval, g in df.groupby(interval_index): + pass + + +def test_groupby_result_for_ambiguous_indexes() -> None: + # GH 674 + df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}) + # this will use pd.Index which is ambiguous + iterator = df.groupby(df.index).__iter__() + assert_type(iterator, Iterator[Tuple[Any, pd.DataFrame]]) + index, value = next(iterator) + assert_type((index, value), Tuple[Any, pd.DataFrame]) + + check(assert_type(index, Any), int) + check(assert_type(value, pd.DataFrame), pd.DataFrame) + + # categorical indexes are also ambiguous + categorical_index = pd.CategoricalIndex(df.a) + iterator2 = df.groupby(categorical_index).__iter__() + assert_type(iterator2, Iterator[Tuple[Any, pd.DataFrame]]) + index2, value2 = next(iterator2) + assert_type((index2, value2), Tuple[Any, pd.DataFrame]) + + check(assert_type(index2, Any), int) + check(assert_type(value2, pd.DataFrame), pd.DataFrame) + def test_setitem_list(): # GH 153