
Commit 002f0cc

Refactors Series.groupby to match DataFrame.groupby (#690)
* Refactors `Series.groupby` to match `DataFrame.groupby`
* Fixes failing tests on Windows due to integer size.
* Actually fixes failing tests on Windows
* Also avoids ambiguous integer size in the MultiIndex
* Avoids platform specific int size complications by using `np.integer`
* Adds missing type parameters to `DataFrameGroupBy.__getattr__`
1 parent 7a8befd commit 002f0cc
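Note (not part of this commit): the Windows-related items above come down to numpy's platform-dependent default integer width. A minimal sketch of why the new tests compare group values against np.integer rather than a fixed-width type:

# numpy's default integer dtype is platform dependent: typically int32 on
# Windows and int64 on Linux/macOS. Both are subclasses of np.integer, so
# checking against the abstract type passes on every platform.
import numpy as np

value = np.array([0, 1, 2])[0]
assert isinstance(value, np.integer)
print(type(value))  # np.int32 or np.int64, depending on the platform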

File tree: 5 files changed (+216 −38 lines)

* pandas-stubs/_typing.pyi (+1 −2)
* pandas-stubs/core/frame.pyi (+5 −18)
* pandas-stubs/core/groupby/generic.pyi (+4 −9)
* pandas-stubs/core/series.pyi (+72 −9)
* tests/test_series.py (+134 −0)

pandas-stubs/_typing.pyi (+1 −2)

@@ -438,12 +438,11 @@ GroupByObjectNonScalar: TypeAlias = (
     | list[np.ndarray]
     | Mapping[Label, Any]
     | list[Mapping[Label, Any]]
-    | Index
     | list[Index]
     | Grouper
     | list[Grouper]
 )
-GroupByObject: TypeAlias = Scalar | GroupByObjectNonScalar
+GroupByObject: TypeAlias = Scalar | Index | GroupByObjectNonScalar
 
 StataDateFormat: TypeAlias = Literal[
     "tc",

pandas-stubs/core/frame.pyi (+5 −18)

@@ -1018,7 +1018,7 @@ class DataFrame(NDFrame, OpsMixin):
         dropna: _bool = ...,
     ) -> DataFrameGroupBy[Scalar]: ...
     @overload
-    def groupby( # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
+    def groupby(
         self,
         by: DatetimeIndex,
         axis: Axis = ...,
@@ -1031,7 +1031,7 @@ class DataFrame(NDFrame, OpsMixin):
         dropna: _bool = ...,
     ) -> DataFrameGroupBy[Timestamp]: ...
     @overload
-    def groupby( # type: ignore[misc]
+    def groupby(
         self,
         by: TimedeltaIndex,
         axis: Axis = ...,
@@ -1044,7 +1044,7 @@ class DataFrame(NDFrame, OpsMixin):
         dropna: _bool = ...,
     ) -> DataFrameGroupBy[Timedelta]: ...
     @overload
-    def groupby( # type: ignore[misc]
+    def groupby(
         self,
         by: PeriodIndex,
         axis: Axis = ...,
@@ -1057,7 +1057,7 @@ class DataFrame(NDFrame, OpsMixin):
         dropna: _bool = ...,
     ) -> DataFrameGroupBy[Period]: ...
     @overload
-    def groupby( # type: ignore[misc]
+    def groupby(
         self,
         by: IntervalIndex[IntervalT],
         axis: Axis = ...,
@@ -1072,7 +1072,7 @@ class DataFrame(NDFrame, OpsMixin):
     @overload
     def groupby(
         self,
-        by: MultiIndex,
+        by: MultiIndex | GroupByObjectNonScalar | None = ...,
         axis: Axis = ...,
         level: Level | None = ...,
         as_index: _bool = ...,
@@ -1095,19 +1095,6 @@ class DataFrame(NDFrame, OpsMixin):
         observed: _bool = ...,
         dropna: _bool = ...,
     ) -> DataFrameGroupBy[Any]: ...
-    @overload
-    def groupby(
-        self,
-        by: GroupByObjectNonScalar | None = ...,
-        axis: Axis = ...,
-        level: Level | None = ...,
-        as_index: _bool = ...,
-        sort: _bool = ...,
-        group_keys: _bool = ...,
-        squeeze: _bool = ...,
-        observed: _bool = ...,
-        dropna: _bool = ...,
-    ) -> DataFrameGroupBy[tuple]: ...
     def pivot(
         self,
         *,
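Note (not part of the diff): the specialised index overloads describe the group-key type that falls out of iteration at runtime. A rough sketch of the behaviour the DatetimeIndex overload models:

import pandas as pd

df = pd.DataFrame(
    {"x": [1, 2, 3]},
    index=pd.date_range("2024-01-01", periods=3, freq="D"),
)
# Grouping by the DatetimeIndex: each group key is a pd.Timestamp,
# which is what DataFrameGroupBy[Timestamp] models in the stubs.
for ts, sub in df.groupby(df.index):
    print(type(ts).__name__, len(sub))  # Timestamp 1, three times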

pandas-stubs/core/groupby/generic.pyi (+4 −9)

@@ -46,13 +46,7 @@ class NamedAgg(NamedTuple):
 
 def generate_property(name: str, klass: type[NDFrame]): ...
 
-class _SeriesGroupByScalar(SeriesGroupBy[S1]):
-    def __iter__(self) -> Iterator[tuple[Scalar, Series]]: ...
-
-class _SeriesGroupByNonScalar(SeriesGroupBy[S1]):
-    def __iter__(self) -> Iterator[tuple[tuple, Series]]: ...
-
-class SeriesGroupBy(GroupBy, Generic[S1]):
+class SeriesGroupBy(GroupBy, Generic[S1, ByT]):
     def any(self, skipna: bool = ...) -> Series[bool]: ...
     def all(self, skipna: bool = ...) -> Series[bool]: ...
     def apply(self, func, *args, **kwargs) -> Series: ...
@@ -146,6 +140,7 @@ class SeriesGroupBy(GroupBy, Generic[S1]):
     ) -> AxesSubplot: ...
     def idxmax(self, axis: Axis = ..., skipna: bool = ...) -> Series: ...
     def idxmin(self, axis: Axis = ..., skipna: bool = ...) -> Series: ...
+    def __iter__(self) -> Iterator[tuple[ByT, Series[S1]]]: ...
 
 class DataFrameGroupBy(GroupBy, Generic[ByT]):
     def any(self, skipna: bool = ...) -> DataFrame: ...
@@ -171,7 +166,7 @@ class DataFrameGroupBy(GroupBy, Generic[ByT]):
     ) -> DataFrame: ...
     def nunique(self, dropna: bool = ...) -> DataFrame: ...
     @overload
-    def __getitem__(self, item: str) -> SeriesGroupBy: ...
+    def __getitem__(self, item: str) -> SeriesGroupBy[Any, ByT]: ...
     @overload
     def __getitem__(self, item: list[str]) -> DataFrameGroupBy[ByT]: ...
     def count(self) -> DataFrame: ...
@@ -358,5 +353,5 @@ class DataFrameGroupBy(GroupBy, Generic[ByT]):
         ascending: bool = ...,
         dropna: bool = ...,
     ) -> Series[float]: ...
-    def __getattr__(self, name: str) -> SeriesGroupBy: ...
+    def __getattr__(self, name: str) -> SeriesGroupBy[Any, ByT]: ...
     def __iter__(self) -> Iterator[tuple[ByT, DataFrame]]: ...
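Note (not part of the diff): SeriesGroupBy now carries both the element type (S1) and the group-key type (ByT), so selecting a column from a DataFrameGroupBy keeps the key type through iteration. A small usage sketch:

import pandas as pd

df = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})
gb = df.groupby("k")   # DataFrameGroupBy[Scalar] under these stubs
col = gb["v"]          # now typed as SeriesGroupBy[Any, Scalar]
# Iteration yields (key, Series) pairs; the key type comes from ByT.
for key, values in col:
    print(key, values.tolist())  # a [1, 2] then b [3]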

pandas-stubs/core/series.pyi (+72 −9)

@@ -39,10 +39,7 @@ from pandas import (
 from pandas.core.arrays.base import ExtensionArray
 from pandas.core.arrays.categorical import CategoricalAccessor
 from pandas.core.arrays.interval import IntervalArray
-from pandas.core.groupby.generic import (
-    _SeriesGroupByNonScalar,
-    _SeriesGroupByScalar,
-)
+from pandas.core.groupby.generic import SeriesGroupBy
 from pandas.core.indexers import BaseIndexer
 from pandas.core.indexes.accessors import (
     CombinedDatetimelikeProperties,
@@ -51,8 +48,10 @@ from pandas.core.indexes.accessors import (
     TimestampProperties,
 )
 from pandas.core.indexes.base import Index
+from pandas.core.indexes.category import CategoricalIndex
 from pandas.core.indexes.datetimes import DatetimeIndex
 from pandas.core.indexes.interval import IntervalIndex
+from pandas.core.indexes.multi import MultiIndex
 from pandas.core.indexes.period import PeriodIndex
 from pandas.core.indexes.timedeltas import TimedeltaIndex
 from pandas.core.indexing import (
@@ -65,7 +64,6 @@ from pandas.core.strings import StringMethods
 from pandas.core.window import (
     Expanding,
     ExponentialMovingWindow,
-    Rolling,
 )
 from pandas.core.window.rolling import (
     Rolling,
@@ -113,6 +111,7 @@ from pandas._typing import (
     IndexingInt,
     IntDtypeArg,
     IntervalClosedType,
+    IntervalT,
     JoinHow,
     JsonSeriesOrient,
     Level,
@@ -144,7 +143,6 @@ from pandas.plotting import PlotAccessor
 from .base import IndexOpsMixin
 from .frame import DataFrame
 from .generic import NDFrame
-from .indexes.multi import MultiIndex
 from .indexing import (
     _iLocIndexer,
     _LocIndexer,
@@ -537,11 +535,76 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
         squeeze: _bool = ...,
         observed: _bool = ...,
         dropna: _bool = ...,
-    ) -> _SeriesGroupByScalar[S1]: ...
+    ) -> SeriesGroupBy[S1, Scalar]: ...
+    @overload
+    def groupby(
+        self,
+        by: DatetimeIndex,
+        axis: AxisIndex = ...,
+        level: Level | None = ...,
+        as_index: _bool = ...,
+        sort: _bool = ...,
+        group_keys: _bool = ...,
+        squeeze: _bool = ...,
+        observed: _bool = ...,
+        dropna: _bool = ...,
+    ) -> SeriesGroupBy[S1, Timestamp]: ...
+    @overload
+    def groupby(
+        self,
+        by: TimedeltaIndex,
+        axis: AxisIndex = ...,
+        level: Level | None = ...,
+        as_index: _bool = ...,
+        sort: _bool = ...,
+        group_keys: _bool = ...,
+        squeeze: _bool = ...,
+        observed: _bool = ...,
+        dropna: _bool = ...,
+    ) -> SeriesGroupBy[S1, Timedelta]: ...
+    @overload
+    def groupby(
+        self,
+        by: PeriodIndex,
+        axis: AxisIndex = ...,
+        level: Level | None = ...,
+        as_index: _bool = ...,
+        sort: _bool = ...,
+        group_keys: _bool = ...,
+        squeeze: _bool = ...,
+        observed: _bool = ...,
+        dropna: _bool = ...,
+    ) -> SeriesGroupBy[S1, Period]: ...
+    @overload
+    def groupby(
+        self,
+        by: IntervalIndex[IntervalT],
+        axis: AxisIndex = ...,
+        level: Level | None = ...,
+        as_index: _bool = ...,
+        sort: _bool = ...,
+        group_keys: _bool = ...,
+        squeeze: _bool = ...,
+        observed: _bool = ...,
+        dropna: _bool = ...,
+    ) -> SeriesGroupBy[S1, IntervalT]: ...
+    @overload
+    def groupby(
+        self,
+        by: MultiIndex | GroupByObjectNonScalar = ...,
+        axis: AxisIndex = ...,
+        level: Level | None = ...,
+        as_index: _bool = ...,
+        sort: _bool = ...,
+        group_keys: _bool = ...,
+        squeeze: _bool = ...,
+        observed: _bool = ...,
+        dropna: _bool = ...,
+    ) -> SeriesGroupBy[S1, tuple]: ...
     @overload
     def groupby(
         self,
-        by: GroupByObjectNonScalar = ...,
+        by: CategoricalIndex | Index,
         axis: AxisIndex = ...,
         level: Level | None = ...,
         as_index: _bool = ...,
@@ -550,7 +613,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
         squeeze: _bool = ...,
         observed: _bool = ...,
         dropna: _bool = ...,
-    ) -> _SeriesGroupByNonScalar[S1]: ...
+    ) -> SeriesGroupBy[S1, Any]: ...
     # need the ignore because None is Hashable
     @overload
     def count(self, level: None = ...) -> int: ...  # type: ignore[misc]
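Note (not part of the diff): at runtime these overloads describe which key type falls out of iterating the groupby: a scalar label gives scalar keys, a list of levels or a MultiIndex gives tuples, and the typed indexes give their element types. A brief sketch of the scalar versus tuple cases:

import pandas as pd

mi = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=["a", "b"])
s = pd.Series([10, 20, 30], index=mi)

# Scalar overload: grouping by one level name yields scalar keys.
for key, grp in s.groupby("a"):
    print(key)  # 0, then 1

# Non-scalar overload: grouping by several levels yields tuple keys.
for key, grp in s.groupby(["a", "b"]):
    print(key)  # (0, 0), (0, 1), (1, 0)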

tests/test_series.py (+134 −0)

@@ -565,6 +565,140 @@ def test_types_groupby_methods() -> None:
     check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series)
 
 
+def test_groupby_result() -> None:
+    # GH 142
+    # since there are no columns in a Series, groupby name only works
+    # with a named index, we use a MultiIndex, so we can group by more
+    # than one level and test the non-scalar case
+    multi_index = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=["a", "b"])
+    s = pd.Series([0, 1, 2], index=multi_index, dtype=int)
+    iterator = s.groupby(["a", "b"]).__iter__()
+    assert_type(iterator, Iterator[Tuple[Tuple, "pd.Series[int]"]])
+    index, value = next(iterator)
+    assert_type((index, value), Tuple[Tuple, "pd.Series[int]"])
+
+    check(assert_type(index, Tuple), tuple, np.integer)
+    check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer)
+
+    iterator2 = s.groupby("a").__iter__()
+    assert_type(iterator2, Iterator[Tuple[Scalar, "pd.Series[int]"]])
+    index2, value2 = next(iterator2)
+    assert_type((index2, value2), Tuple[Scalar, "pd.Series[int]"])
+
+    check(assert_type(index2, Scalar), int)
+    check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer)
+
+    # GH 674
+    # grouping by pd.MultiIndex should always resolve to a tuple as well
+    iterator3 = s.groupby(multi_index).__iter__()
+    assert_type(iterator3, Iterator[Tuple[Tuple, "pd.Series[int]"]])
+    index3, value3 = next(iterator3)
+    assert_type((index3, value3), Tuple[Tuple, "pd.Series[int]"])
+
+    check(assert_type(index3, Tuple), tuple, int)
+    check(assert_type(value3, "pd.Series[int]"), pd.Series, np.integer)
+
+    # Want to make sure these cases are differentiated
+    for (k1, k2), g in s.groupby(["a", "b"]):
+        pass
+
+    for kk, g in s.groupby("a"):
+        pass
+
+    for (k1, k2), g in s.groupby(multi_index):
+        pass
+
+
+def test_groupby_result_for_scalar_indexes() -> None:
+    # GH 674
+    s = pd.Series([0, 1, 2], dtype=int)
+    dates = pd.Series(
+        [
+            pd.Timestamp("2020-01-01"),
+            pd.Timestamp("2020-01-15"),
+            pd.Timestamp("2020-02-01"),
+        ],
+        dtype="datetime64[ns]",
+    )
+
+    period_index = pd.PeriodIndex(dates, freq="M")
+    iterator = s.groupby(period_index).__iter__()
+    assert_type(iterator, Iterator[Tuple[pd.Period, "pd.Series[int]"]])
+    index, value = next(iterator)
+    assert_type((index, value), Tuple[pd.Period, "pd.Series[int]"])
+
+    check(assert_type(index, pd.Period), pd.Period)
+    check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer)
+
+    dt_index = pd.DatetimeIndex(dates)
+    iterator2 = s.groupby(dt_index).__iter__()
+    assert_type(iterator2, Iterator[Tuple[pd.Timestamp, "pd.Series[int]"]])
+    index2, value2 = next(iterator2)
+    assert_type((index2, value2), Tuple[pd.Timestamp, "pd.Series[int]"])
+
+    check(assert_type(index2, pd.Timestamp), pd.Timestamp)
+    check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer)
+
+    tdelta_index = pd.TimedeltaIndex(dates - pd.Timestamp("2020-01-01"))
+    iterator3 = s.groupby(tdelta_index).__iter__()
+    assert_type(iterator3, Iterator[Tuple[pd.Timedelta, "pd.Series[int]"]])
+    index3, value3 = next(iterator3)
+    assert_type((index3, value3), Tuple[pd.Timedelta, "pd.Series[int]"])
+
+    check(assert_type(index3, pd.Timedelta), pd.Timedelta)
+    check(assert_type(value3, "pd.Series[int]"), pd.Series, np.integer)
+
+    intervals: list[pd.Interval[pd.Timestamp]] = [
+        pd.Interval(date, date + pd.DateOffset(days=1), closed="left") for date in dates
+    ]
+    interval_index = pd.IntervalIndex(intervals)
+    assert_type(interval_index, "pd.IntervalIndex[pd.Interval[pd.Timestamp]]")
+    iterator4 = s.groupby(interval_index).__iter__()
+    assert_type(
+        iterator4, Iterator[Tuple["pd.Interval[pd.Timestamp]", "pd.Series[int]"]]
+    )
+    index4, value4 = next(iterator4)
+    assert_type((index4, value4), Tuple["pd.Interval[pd.Timestamp]", "pd.Series[int]"])
+
+    check(assert_type(index4, "pd.Interval[pd.Timestamp]"), pd.Interval)
+    check(assert_type(value4, "pd.Series[int]"), pd.Series, np.integer)
+
+    for p, g in s.groupby(period_index):
+        pass
+
+    for dt, g in s.groupby(dt_index):
+        pass
+
+    for tdelta, g in s.groupby(tdelta_index):
+        pass
+
+    for interval, g in s.groupby(interval_index):
+        pass
+
+
+def test_groupby_result_for_ambiguous_indexes() -> None:
+    # GH 674
+    s = pd.Series([0, 1, 2], index=["a", "b", "a"], dtype=int)
+    # this will use pd.Index which is ambiguous
+    iterator = s.groupby(s.index).__iter__()
+    assert_type(iterator, Iterator[Tuple[Any, "pd.Series[int]"]])
+    index, value = next(iterator)
+    assert_type((index, value), Tuple[Any, "pd.Series[int]"])
+
+    check(assert_type(index, Any), str)
+    check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer)
+
+    # categorical indexes are also ambiguous
+    categorical_index = pd.CategoricalIndex(s.index)
+    iterator2 = s.groupby(categorical_index).__iter__()
+    assert_type(iterator2, Iterator[Tuple[Any, "pd.Series[int]"]])
+    index2, value2 = next(iterator2)
+    assert_type((index2, value2), Tuple[Any, "pd.Series[int]"])
+
+    check(assert_type(index2, Any), str)
+    check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer)
+
+
 def test_types_groupby_agg() -> None:
     s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
     check(assert_type(s.groupby(level=0).agg("sum"), pd.Series), pd.Series)
