Skip to content

Commit bd817d2

Browse files
authored
Extends DataFrame.groupby overloads to recognize some scalar index types (#679)
* Extends groupby on DataFrame to recognize some scalar index types
* Make `DataFrameGroupBy` generic. Add new `TypeVar` `ByT`. Add additional overloads to `DataFrame.groupby` to cover all index types
1 parent 47d2ce4 commit bd817d2

File tree

4 files changed

+210
-14
lines changed

4 files changed

+210
-14
lines changed

pandas-stubs/_typing.pyi

+23
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,29 @@ Function: TypeAlias = np.ufunc | Callable[..., Any]
404404
# shared HashableT and HashableT#. This one can be used if the identical
405405
# type is needed in a function that uses GroupByObjectNonScalar
406406
_HashableTa = TypeVar("_HashableTa", bound=Hashable)
407+
ByT = TypeVar(
408+
"ByT",
409+
str,
410+
bytes,
411+
datetime.date,
412+
datetime.datetime,
413+
datetime.timedelta,
414+
np.datetime64,
415+
np.timedelta64,
416+
bool,
417+
int,
418+
float,
419+
complex,
420+
Timestamp,
421+
Timedelta,
422+
Scalar,
423+
Period,
424+
Interval[int],
425+
Interval[float],
426+
Interval[Timestamp],
427+
Interval[Timedelta],
428+
tuple,
429+
)
407430
GroupByObjectNonScalar: TypeAlias = (
408431
tuple
409432
| list[_HashableTa]

pandas-stubs/core/frame.pyi

+89-6
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ from typing import (
2020
from matplotlib.axes import Axes as PlotAxes
2121
import numpy as np
2222
from pandas import (
23+
Period,
2324
Timedelta,
2425
Timestamp,
2526
)
2627
from pandas.core.arraylike import OpsMixin
2728
from pandas.core.generic import NDFrame
28-
from pandas.core.groupby.generic import (
29-
_DataFrameGroupByNonScalar,
30-
_DataFrameGroupByScalar,
31-
)
29+
from pandas.core.groupby.generic import DataFrameGroupBy
3230
from pandas.core.groupby.grouper import Grouper
3331
from pandas.core.indexers import BaseIndexer
3432
from pandas.core.indexes.base import Index
33+
from pandas.core.indexes.category import CategoricalIndex
34+
from pandas.core.indexes.datetimes import DatetimeIndex
35+
from pandas.core.indexes.interval import IntervalIndex
36+
from pandas.core.indexes.multi import MultiIndex
37+
from pandas.core.indexes.period import PeriodIndex
38+
from pandas.core.indexes.timedeltas import TimedeltaIndex
3539
from pandas.core.indexing import (
3640
_iLocIndexer,
3741
_IndexSliceTuple,
@@ -82,6 +86,7 @@ from pandas._typing import (
8286
IndexLabel,
8387
IndexType,
8488
IntervalClosedType,
89+
IntervalT,
8590
JoinHow,
8691
JsonFrameOrient,
8792
Label,
@@ -1011,7 +1016,85 @@ class DataFrame(NDFrame, OpsMixin):
10111016
squeeze: _bool = ...,
10121017
observed: _bool = ...,
10131018
dropna: _bool = ...,
1014-
) -> _DataFrameGroupByScalar: ...
1019+
) -> DataFrameGroupBy[Scalar]: ...
1020+
@overload
1021+
def groupby( # type: ignore[misc] # pyright: ignore[reportOverlappingOverload]
1022+
self,
1023+
by: DatetimeIndex,
1024+
axis: Axis = ...,
1025+
level: Level | None = ...,
1026+
as_index: _bool = ...,
1027+
sort: _bool = ...,
1028+
group_keys: _bool = ...,
1029+
squeeze: _bool = ...,
1030+
observed: _bool = ...,
1031+
dropna: _bool = ...,
1032+
) -> DataFrameGroupBy[Timestamp]: ...
1033+
@overload
1034+
def groupby( # type: ignore[misc]
1035+
self,
1036+
by: TimedeltaIndex,
1037+
axis: Axis = ...,
1038+
level: Level | None = ...,
1039+
as_index: _bool = ...,
1040+
sort: _bool = ...,
1041+
group_keys: _bool = ...,
1042+
squeeze: _bool = ...,
1043+
observed: _bool = ...,
1044+
dropna: _bool = ...,
1045+
) -> DataFrameGroupBy[Timedelta]: ...
1046+
@overload
1047+
def groupby( # type: ignore[misc]
1048+
self,
1049+
by: PeriodIndex,
1050+
axis: Axis = ...,
1051+
level: Level | None = ...,
1052+
as_index: _bool = ...,
1053+
sort: _bool = ...,
1054+
group_keys: _bool = ...,
1055+
squeeze: _bool = ...,
1056+
observed: _bool = ...,
1057+
dropna: _bool = ...,
1058+
) -> DataFrameGroupBy[Period]: ...
1059+
@overload
1060+
def groupby( # type: ignore[misc]
1061+
self,
1062+
by: IntervalIndex[IntervalT],
1063+
axis: Axis = ...,
1064+
level: Level | None = ...,
1065+
as_index: _bool = ...,
1066+
sort: _bool = ...,
1067+
group_keys: _bool = ...,
1068+
squeeze: _bool = ...,
1069+
observed: _bool = ...,
1070+
dropna: _bool = ...,
1071+
) -> DataFrameGroupBy[IntervalT]: ...
1072+
@overload
1073+
def groupby(
1074+
self,
1075+
by: MultiIndex,
1076+
axis: Axis = ...,
1077+
level: Level | None = ...,
1078+
as_index: _bool = ...,
1079+
sort: _bool = ...,
1080+
group_keys: _bool = ...,
1081+
squeeze: _bool = ...,
1082+
observed: _bool = ...,
1083+
dropna: _bool = ...,
1084+
) -> DataFrameGroupBy[tuple]: ...
1085+
@overload
1086+
def groupby(
1087+
self,
1088+
by: CategoricalIndex | Index,
1089+
axis: Axis = ...,
1090+
level: Level | None = ...,
1091+
as_index: _bool = ...,
1092+
sort: _bool = ...,
1093+
group_keys: _bool = ...,
1094+
squeeze: _bool = ...,
1095+
observed: _bool = ...,
1096+
dropna: _bool = ...,
1097+
) -> DataFrameGroupBy[Any]: ...
10151098
@overload
10161099
def groupby(
10171100
self,
@@ -1024,7 +1107,7 @@ class DataFrame(NDFrame, OpsMixin):
10241107
squeeze: _bool = ...,
10251108
observed: _bool = ...,
10261109
dropna: _bool = ...,
1027-
) -> _DataFrameGroupByNonScalar: ...
1110+
) -> DataFrameGroupBy[tuple]: ...
10281111
def pivot(
10291112
self,
10301113
*,

pandas-stubs/core/groupby/generic.pyi

+4-8
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ from pandas._typing import (
3030
AggFuncTypeBase,
3131
AggFuncTypeFrame,
3232
Axis,
33+
ByT,
3334
Level,
3435
ListLike,
3536
RandomState,
@@ -146,13 +147,7 @@ class SeriesGroupBy(GroupBy, Generic[S1]):
146147
def idxmax(self, axis: Axis = ..., skipna: bool = ...) -> Series: ...
147148
def idxmin(self, axis: Axis = ..., skipna: bool = ...) -> Series: ...
148149

149-
class _DataFrameGroupByScalar(DataFrameGroupBy):
150-
def __iter__(self) -> Iterator[tuple[Scalar, DataFrame]]: ...
151-
152-
class _DataFrameGroupByNonScalar(DataFrameGroupBy):
153-
def __iter__(self) -> Iterator[tuple[tuple, DataFrame]]: ...
154-
155-
class DataFrameGroupBy(GroupBy):
150+
class DataFrameGroupBy(GroupBy, Generic[ByT]):
156151
def any(self, skipna: bool = ...) -> DataFrame: ...
157152
def all(self, skipna: bool = ...) -> DataFrame: ...
158153
# error: Overload 3 for "apply" will never be used because its parameters overlap overload 1
@@ -178,7 +173,7 @@ class DataFrameGroupBy(GroupBy):
178173
@overload
179174
def __getitem__(self, item: str) -> SeriesGroupBy: ...
180175
@overload
181-
def __getitem__(self, item: list[str]) -> DataFrameGroupBy: ...
176+
def __getitem__(self, item: list[str]) -> DataFrameGroupBy[ByT]: ...
182177
def count(self) -> DataFrame: ...
183178
def boxplot(
184179
self,
@@ -364,3 +359,4 @@ class DataFrameGroupBy(GroupBy):
364359
dropna: bool = ...,
365360
) -> Series[float]: ...
366361
def __getattr__(self, name: str) -> SeriesGroupBy: ...
362+
def __iter__(self) -> Iterator[tuple[ByT, DataFrame]]: ...

tests/test_frame.py

+94
Original file line numberDiff line numberDiff line change
@@ -1972,13 +1972,107 @@ def test_groupby_result() -> None:
19721972
check(assert_type(index2, Scalar), int)
19731973
check(assert_type(value2, pd.DataFrame), pd.DataFrame)
19741974

1975+
# GH 674
1976+
# grouping by pd.MultiIndex should always resolve to a tuple as well
1977+
multi_index = pd.MultiIndex.from_frame(df[["a", "b"]])
1978+
iterator3 = df.groupby(multi_index).__iter__()
1979+
assert_type(iterator3, Iterator[Tuple[Tuple, pd.DataFrame]])
1980+
index3, value3 = next(iterator3)
1981+
assert_type((index3, value3), Tuple[Tuple, pd.DataFrame])
1982+
1983+
check(assert_type(index3, Tuple), tuple, int)
1984+
check(assert_type(value3, pd.DataFrame), pd.DataFrame)
1985+
19751986
# Want to make sure these cases are differentiated
19761987
for (k1, k2), g in df.groupby(["a", "b"]):
19771988
pass
19781989

19791990
for kk, g in df.groupby("a"):
19801991
pass
19811992

1993+
for (k1, k2), g in df.groupby(multi_index):
1994+
pass
1995+
1996+
1997+
def test_groupby_result_for_scalar_indexes() -> None:
1998+
# GH 674
1999+
dates = pd.date_range("2020-01-01", "2020-12-31")
2000+
df = pd.DataFrame({"date": dates, "days": 1})
2001+
period_index = pd.PeriodIndex(df.date, freq="M")
2002+
iterator = df.groupby(period_index).__iter__()
2003+
assert_type(iterator, Iterator[Tuple[pd.Period, pd.DataFrame]])
2004+
index, value = next(iterator)
2005+
assert_type((index, value), Tuple[pd.Period, pd.DataFrame])
2006+
2007+
check(assert_type(index, pd.Period), pd.Period)
2008+
check(assert_type(value, pd.DataFrame), pd.DataFrame)
2009+
2010+
dt_index = pd.DatetimeIndex(dates)
2011+
iterator2 = df.groupby(dt_index).__iter__()
2012+
assert_type(iterator2, Iterator[Tuple[pd.Timestamp, pd.DataFrame]])
2013+
index2, value2 = next(iterator2)
2014+
assert_type((index2, value2), Tuple[pd.Timestamp, pd.DataFrame])
2015+
2016+
check(assert_type(index2, pd.Timestamp), pd.Timestamp)
2017+
check(assert_type(value2, pd.DataFrame), pd.DataFrame)
2018+
2019+
tdelta_index = pd.TimedeltaIndex(dates - pd.Timestamp("2020-01-01"))
2020+
iterator3 = df.groupby(tdelta_index).__iter__()
2021+
assert_type(iterator3, Iterator[Tuple[pd.Timedelta, pd.DataFrame]])
2022+
index3, value3 = next(iterator3)
2023+
assert_type((index3, value3), Tuple[pd.Timedelta, pd.DataFrame])
2024+
2025+
check(assert_type(index3, pd.Timedelta), pd.Timedelta)
2026+
check(assert_type(value3, pd.DataFrame), pd.DataFrame)
2027+
2028+
intervals: list[pd.Interval[pd.Timestamp]] = [
2029+
pd.Interval(date, date + pd.DateOffset(days=1), closed="left") for date in dates
2030+
]
2031+
interval_index = pd.IntervalIndex(intervals)
2032+
assert_type(interval_index, "pd.IntervalIndex[pd.Interval[pd.Timestamp]]")
2033+
iterator4 = df.groupby(interval_index).__iter__()
2034+
assert_type(iterator4, Iterator[Tuple["pd.Interval[pd.Timestamp]", pd.DataFrame]])
2035+
index4, value4 = next(iterator4)
2036+
assert_type((index4, value4), Tuple["pd.Interval[pd.Timestamp]", pd.DataFrame])
2037+
2038+
check(assert_type(index4, "pd.Interval[pd.Timestamp]"), pd.Interval)
2039+
check(assert_type(value4, pd.DataFrame), pd.DataFrame)
2040+
2041+
for p, g in df.groupby(period_index):
2042+
pass
2043+
2044+
for dt, g in df.groupby(dt_index):
2045+
pass
2046+
2047+
for tdelta, g in df.groupby(tdelta_index):
2048+
pass
2049+
2050+
for interval, g in df.groupby(interval_index):
2051+
pass
2052+
2053+
2054+
def test_groupby_result_for_ambiguous_indexes() -> None:
2055+
# GH 674
2056+
df = pd.DataFrame({"a": [0, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]})
2057+
# this will use pd.Index which is ambiguous
2058+
iterator = df.groupby(df.index).__iter__()
2059+
assert_type(iterator, Iterator[Tuple[Any, pd.DataFrame]])
2060+
index, value = next(iterator)
2061+
assert_type((index, value), Tuple[Any, pd.DataFrame])
2062+
2063+
check(assert_type(index, Any), int)
2064+
check(assert_type(value, pd.DataFrame), pd.DataFrame)
2065+
2066+
# categorical indexes are also ambiguous
2067+
categorical_index = pd.CategoricalIndex(df.a)
2068+
iterator2 = df.groupby(categorical_index).__iter__()
2069+
assert_type(iterator2, Iterator[Tuple[Any, pd.DataFrame]])
2070+
index2, value2 = next(iterator2)
2071+
assert_type((index2, value2), Tuple[Any, pd.DataFrame])
2072+
2073+
check(assert_type(index2, Any), int)
2074+
check(assert_type(value2, pd.DataFrame), pd.DataFrame)
2075+
19822076

19832077
def test_setitem_list():
19842078
# GH 153

0 commit comments

Comments
 (0)