diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 0e3071502..8eca9d0c5 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -438,12 +438,11 @@ GroupByObjectNonScalar: TypeAlias = ( | list[np.ndarray] | Mapping[Label, Any] | list[Mapping[Label, Any]] - | Index | list[Index] | Grouper | list[Grouper] ) -GroupByObject: TypeAlias = Scalar | GroupByObjectNonScalar +GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar StataDateFormat: TypeAlias = Literal[ "tc", diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 343a6c4c8..e7b7f5abf 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1018,7 +1018,7 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Scalar]: ... @overload - def groupby( # type: ignore[misc] # pyright: ignore[reportOverlappingOverload] + def groupby( self, by: DatetimeIndex, axis: Axis = ..., @@ -1031,7 +1031,7 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Timestamp]: ... @overload - def groupby( # type: ignore[misc] + def groupby( self, by: TimedeltaIndex, axis: Axis = ..., @@ -1044,7 +1044,7 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Timedelta]: ... @overload - def groupby( # type: ignore[misc] + def groupby( self, by: PeriodIndex, axis: Axis = ..., @@ -1057,7 +1057,7 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Period]: ... @overload - def groupby( # type: ignore[misc] + def groupby( self, by: IntervalIndex[IntervalT], axis: Axis = ..., @@ -1072,7 +1072,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def groupby( self, - by: MultiIndex, + by: MultiIndex | GroupByObjectNonScalar | None = ..., axis: Axis = ..., level: Level | None = ..., as_index: _bool = ..., @@ -1095,19 +1095,6 @@ class DataFrame(NDFrame, OpsMixin): observed: _bool = ..., dropna: _bool = ..., ) -> DataFrameGroupBy[Any]: ... - @overload - def groupby( - self, - by: GroupByObjectNonScalar | None = ..., - axis: Axis = ..., - level: Level | None = ..., - as_index: _bool = ..., - sort: _bool = ..., - group_keys: _bool = ..., - squeeze: _bool = ..., - observed: _bool = ..., - dropna: _bool = ..., - ) -> DataFrameGroupBy[tuple]: ... def pivot( self, *, diff --git a/pandas-stubs/core/groupby/generic.pyi b/pandas-stubs/core/groupby/generic.pyi index 971fe2aca..9b1d93241 100644 --- a/pandas-stubs/core/groupby/generic.pyi +++ b/pandas-stubs/core/groupby/generic.pyi @@ -46,13 +46,7 @@ class NamedAgg(NamedTuple): def generate_property(name: str, klass: type[NDFrame]): ... -class _SeriesGroupByScalar(SeriesGroupBy[S1]): - def __iter__(self) -> Iterator[tuple[Scalar, Series]]: ... - -class _SeriesGroupByNonScalar(SeriesGroupBy[S1]): - def __iter__(self) -> Iterator[tuple[tuple, Series]]: ... - -class SeriesGroupBy(GroupBy, Generic[S1]): +class SeriesGroupBy(GroupBy, Generic[S1, ByT]): def any(self, skipna: bool = ...) -> Series[bool]: ... def all(self, skipna: bool = ...) -> Series[bool]: ... def apply(self, func, *args, **kwargs) -> Series: ... @@ -146,6 +140,7 @@ class SeriesGroupBy(GroupBy, Generic[S1]): ) -> AxesSubplot: ... def idxmax(self, axis: Axis = ..., skipna: bool = ...) -> Series: ... def idxmin(self, axis: Axis = ..., skipna: bool = ...) -> Series: ... + def __iter__(self) -> Iterator[tuple[ByT, Series[S1]]]: ... class DataFrameGroupBy(GroupBy, Generic[ByT]): def any(self, skipna: bool = ...) -> DataFrame: ... @@ -171,7 +166,7 @@ class DataFrameGroupBy(GroupBy, Generic[ByT]): ) -> DataFrame: ... def nunique(self, dropna: bool = ...) -> DataFrame: ... @overload - def __getitem__(self, item: str) -> SeriesGroupBy: ... + def __getitem__(self, item: str) -> SeriesGroupBy[Any, ByT]: ... @overload def __getitem__(self, item: list[str]) -> DataFrameGroupBy[ByT]: ... def count(self) -> DataFrame: ... @@ -358,5 +353,5 @@ class DataFrameGroupBy(GroupBy, Generic[ByT]): ascending: bool = ..., dropna: bool = ..., ) -> Series[float]: ... - def __getattr__(self, name: str) -> SeriesGroupBy: ... + def __getattr__(self, name: str) -> SeriesGroupBy[Any, ByT]: ... def __iter__(self) -> Iterator[tuple[ByT, DataFrame]]: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 41c1e9f27..29fbeb489 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -39,10 +39,7 @@ from pandas import ( from pandas.core.arrays.base import ExtensionArray from pandas.core.arrays.categorical import CategoricalAccessor from pandas.core.arrays.interval import IntervalArray -from pandas.core.groupby.generic import ( - _SeriesGroupByNonScalar, - _SeriesGroupByScalar, -) +from pandas.core.groupby.generic import SeriesGroupBy from pandas.core.indexers import BaseIndexer from pandas.core.indexes.accessors import ( CombinedDatetimelikeProperties, @@ -51,8 +48,10 @@ from pandas.core.indexes.accessors import ( TimestampProperties, ) from pandas.core.indexes.base import Index +from pandas.core.indexes.category import CategoricalIndex from pandas.core.indexes.datetimes import DatetimeIndex from pandas.core.indexes.interval import IntervalIndex +from pandas.core.indexes.multi import MultiIndex from pandas.core.indexes.period import PeriodIndex from pandas.core.indexes.timedeltas import TimedeltaIndex from pandas.core.indexing import ( @@ -65,7 +64,6 @@ from pandas.core.strings import StringMethods from pandas.core.window import ( Expanding, ExponentialMovingWindow, - Rolling, ) from pandas.core.window.rolling import ( Rolling, @@ -113,6 +111,7 @@ from pandas._typing import ( IndexingInt, IntDtypeArg, IntervalClosedType, + IntervalT, JoinHow, JsonSeriesOrient, Level, @@ -144,7 +143,6 @@ from pandas.plotting import PlotAccessor from .base import IndexOpsMixin from .frame import DataFrame from .generic import NDFrame -from .indexes.multi import MultiIndex from .indexing import ( _iLocIndexer, _LocIndexer, @@ -537,11 +535,76 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> _SeriesGroupByScalar[S1]: ... + ) -> SeriesGroupBy[S1, Scalar]: ... + @overload + def groupby( + self, + by: DatetimeIndex, + axis: AxisIndex = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy[S1, Timestamp]: ... + @overload + def groupby( + self, + by: TimedeltaIndex, + axis: AxisIndex = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy[S1, Timedelta]: ... + @overload + def groupby( + self, + by: PeriodIndex, + axis: AxisIndex = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy[S1, Period]: ... + @overload + def groupby( + self, + by: IntervalIndex[IntervalT], + axis: AxisIndex = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy[S1, IntervalT]: ... + @overload + def groupby( + self, + by: MultiIndex | GroupByObjectNonScalar = ..., + axis: AxisIndex = ..., + level: Level | None = ..., + as_index: _bool = ..., + sort: _bool = ..., + group_keys: _bool = ..., + squeeze: _bool = ..., + observed: _bool = ..., + dropna: _bool = ..., + ) -> SeriesGroupBy[S1, tuple]: ... @overload def groupby( self, - by: GroupByObjectNonScalar = ..., + by: CategoricalIndex | Index, axis: AxisIndex = ..., level: Level | None = ..., as_index: _bool = ..., @@ -550,7 +613,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): squeeze: _bool = ..., observed: _bool = ..., dropna: _bool = ..., - ) -> _SeriesGroupByNonScalar[S1]: ... + ) -> SeriesGroupBy[S1, Any]: ... # need the ignore because None is Hashable @overload def count(self, level: None = ...) -> int: ... # type: ignore[misc] diff --git a/tests/test_series.py b/tests/test_series.py index f76c2662d..42a64112a 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -565,6 +565,140 @@ def test_types_groupby_methods() -> None: check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series) +def test_groupby_result() -> None: + # GH 142 + # since there are no columns in a Series, groupby name only works + # with a named index, we use a MultiIndex, so we can group by more + # than one level and test the non-scalar case + multi_index = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=["a", "b"]) + s = pd.Series([0, 1, 2], index=multi_index, dtype=int) + iterator = s.groupby(["a", "b"]).__iter__() + assert_type(iterator, Iterator[Tuple[Tuple, "pd.Series[int]"]]) + index, value = next(iterator) + assert_type((index, value), Tuple[Tuple, "pd.Series[int]"]) + + check(assert_type(index, Tuple), tuple, np.integer) + check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer) + + iterator2 = s.groupby("a").__iter__() + assert_type(iterator2, Iterator[Tuple[Scalar, "pd.Series[int]"]]) + index2, value2 = next(iterator2) + assert_type((index2, value2), Tuple[Scalar, "pd.Series[int]"]) + + check(assert_type(index2, Scalar), int) + check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer) + + # GH 674 + # grouping by pd.MultiIndex should always resolve to a tuple as well + iterator3 = s.groupby(multi_index).__iter__() + assert_type(iterator3, Iterator[Tuple[Tuple, "pd.Series[int]"]]) + index3, value3 = next(iterator3) + assert_type((index3, value3), Tuple[Tuple, "pd.Series[int]"]) + + check(assert_type(index3, Tuple), tuple, int) + check(assert_type(value3, "pd.Series[int]"), pd.Series, np.integer) + + # Want to make sure these cases are differentiated + for (k1, k2), g in s.groupby(["a", "b"]): + pass + + for kk, g in s.groupby("a"): + pass + + for (k1, k2), g in s.groupby(multi_index): + pass + + +def test_groupby_result_for_scalar_indexes() -> None: + # GH 674 + s = pd.Series([0, 1, 2], dtype=int) + dates = pd.Series( + [ + pd.Timestamp("2020-01-01"), + pd.Timestamp("2020-01-15"), + pd.Timestamp("2020-02-01"), + ], + dtype="datetime64[ns]", + ) + + period_index = pd.PeriodIndex(dates, freq="M") + iterator = s.groupby(period_index).__iter__() + assert_type(iterator, Iterator[Tuple[pd.Period, "pd.Series[int]"]]) + index, value = next(iterator) + assert_type((index, value), Tuple[pd.Period, "pd.Series[int]"]) + + check(assert_type(index, pd.Period), pd.Period) + check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer) + + dt_index = pd.DatetimeIndex(dates) + iterator2 = s.groupby(dt_index).__iter__() + assert_type(iterator2, Iterator[Tuple[pd.Timestamp, "pd.Series[int]"]]) + index2, value2 = next(iterator2) + assert_type((index2, value2), Tuple[pd.Timestamp, "pd.Series[int]"]) + + check(assert_type(index2, pd.Timestamp), pd.Timestamp) + check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer) + + tdelta_index = pd.TimedeltaIndex(dates - pd.Timestamp("2020-01-01")) + iterator3 = s.groupby(tdelta_index).__iter__() + assert_type(iterator3, Iterator[Tuple[pd.Timedelta, "pd.Series[int]"]]) + index3, value3 = next(iterator3) + assert_type((index3, value3), Tuple[pd.Timedelta, "pd.Series[int]"]) + + check(assert_type(index3, pd.Timedelta), pd.Timedelta) + check(assert_type(value3, "pd.Series[int]"), pd.Series, np.integer) + + intervals: list[pd.Interval[pd.Timestamp]] = [ + pd.Interval(date, date + pd.DateOffset(days=1), closed="left") for date in dates + ] + interval_index = pd.IntervalIndex(intervals) + assert_type(interval_index, "pd.IntervalIndex[pd.Interval[pd.Timestamp]]") + iterator4 = s.groupby(interval_index).__iter__() + assert_type( + iterator4, Iterator[Tuple["pd.Interval[pd.Timestamp]", "pd.Series[int]"]] + ) + index4, value4 = next(iterator4) + assert_type((index4, value4), Tuple["pd.Interval[pd.Timestamp]", "pd.Series[int]"]) + + check(assert_type(index4, "pd.Interval[pd.Timestamp]"), pd.Interval) + check(assert_type(value4, "pd.Series[int]"), pd.Series, np.integer) + + for p, g in s.groupby(period_index): + pass + + for dt, g in s.groupby(dt_index): + pass + + for tdelta, g in s.groupby(tdelta_index): + pass + + for interval, g in s.groupby(interval_index): + pass + + +def test_groupby_result_for_ambiguous_indexes() -> None: + # GH 674 + s = pd.Series([0, 1, 2], index=["a", "b", "a"], dtype=int) + # this will use pd.Index which is ambiguous + iterator = s.groupby(s.index).__iter__() + assert_type(iterator, Iterator[Tuple[Any, "pd.Series[int]"]]) + index, value = next(iterator) + assert_type((index, value), Tuple[Any, "pd.Series[int]"]) + + check(assert_type(index, Any), str) + check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer) + + # categorical indexes are also ambiguous + categorical_index = pd.CategoricalIndex(s.index) + iterator2 = s.groupby(categorical_index).__iter__() + assert_type(iterator2, Iterator[Tuple[Any, "pd.Series[int]"]]) + index2, value2 = next(iterator2) + assert_type((index2, value2), Tuple[Any, "pd.Series[int]"]) + + check(assert_type(index2, Any), str) + check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer) + + def test_types_groupby_agg() -> None: s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"]) check(assert_type(s.groupby(level=0).agg("sum"), pd.Series), pd.Series)