Skip to content

Commit 068679c

Browse files
authored
Allow sequence in groupby level (#837)
* Add test cases for level sequences This test fail currently, as the level parameter currently does not accept any sequences. * Allow sequences for groupby level parameter This fixes #836 * Add assert_type * Remove unnecessary quotes
1 parent c0ca527 commit 068679c

File tree

4 files changed

+45
-16
lines changed

4 files changed

+45
-16
lines changed

pandas-stubs/core/frame.pyi

+8-8
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ class DataFrame(NDFrame, OpsMixin):
10091009
self,
10101010
by: Scalar,
10111011
axis: AxisIndex = ...,
1012-
level: Level | None = ...,
1012+
level: IndexLabel | None = ...,
10131013
as_index: _bool = ...,
10141014
sort: _bool = ...,
10151015
group_keys: _bool = ...,
@@ -1022,7 +1022,7 @@ class DataFrame(NDFrame, OpsMixin):
10221022
self,
10231023
by: DatetimeIndex,
10241024
axis: AxisIndex = ...,
1025-
level: Level | None = ...,
1025+
level: IndexLabel | None = ...,
10261026
as_index: _bool = ...,
10271027
sort: _bool = ...,
10281028
group_keys: _bool = ...,
@@ -1035,7 +1035,7 @@ class DataFrame(NDFrame, OpsMixin):
10351035
self,
10361036
by: TimedeltaIndex,
10371037
axis: AxisIndex = ...,
1038-
level: Level | None = ...,
1038+
level: IndexLabel | None = ...,
10391039
as_index: _bool = ...,
10401040
sort: _bool = ...,
10411041
group_keys: _bool = ...,
@@ -1048,7 +1048,7 @@ class DataFrame(NDFrame, OpsMixin):
10481048
self,
10491049
by: PeriodIndex,
10501050
axis: AxisIndex = ...,
1051-
level: Level | None = ...,
1051+
level: IndexLabel | None = ...,
10521052
as_index: _bool = ...,
10531053
sort: _bool = ...,
10541054
group_keys: _bool = ...,
@@ -1061,7 +1061,7 @@ class DataFrame(NDFrame, OpsMixin):
10611061
self,
10621062
by: IntervalIndex[IntervalT],
10631063
axis: AxisIndex = ...,
1064-
level: Level | None = ...,
1064+
level: IndexLabel | None = ...,
10651065
as_index: _bool = ...,
10661066
sort: _bool = ...,
10671067
group_keys: _bool = ...,
@@ -1074,7 +1074,7 @@ class DataFrame(NDFrame, OpsMixin):
10741074
self,
10751075
by: MultiIndex | GroupByObjectNonScalar | None = ...,
10761076
axis: AxisIndex = ...,
1077-
level: Level | None = ...,
1077+
level: IndexLabel | None = ...,
10781078
as_index: _bool = ...,
10791079
sort: _bool = ...,
10801080
group_keys: _bool = ...,
@@ -1087,7 +1087,7 @@ class DataFrame(NDFrame, OpsMixin):
10871087
self,
10881088
by: Series[SeriesByT],
10891089
axis: AxisIndex = ...,
1090-
level: Level | None = ...,
1090+
level: IndexLabel | None = ...,
10911091
as_index: _bool = ...,
10921092
sort: _bool = ...,
10931093
group_keys: _bool = ...,
@@ -1100,7 +1100,7 @@ class DataFrame(NDFrame, OpsMixin):
11001100
self,
11011101
by: CategoricalIndex | Index | Series,
11021102
axis: AxisIndex = ...,
1103-
level: Level | None = ...,
1103+
level: IndexLabel | None = ...,
11041104
as_index: _bool = ...,
11051105
sort: _bool = ...,
11061106
group_keys: _bool = ...,

pandas-stubs/core/series.pyi

+9-8
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ from pandas._typing import (
117117
HashableT3,
118118
IgnoreRaise,
119119
IndexingInt,
120+
IndexLabel,
120121
IntDtypeArg,
121122
InterpolateOptions,
122123
IntervalClosedType,
@@ -547,7 +548,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
547548
self,
548549
by: Scalar,
549550
axis: AxisIndex = ...,
550-
level: Level | None = ...,
551+
level: IndexLabel | None = ...,
551552
as_index: _bool = ...,
552553
sort: _bool = ...,
553554
group_keys: _bool = ...,
@@ -560,7 +561,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
560561
self,
561562
by: DatetimeIndex,
562563
axis: AxisIndex = ...,
563-
level: Level | None = ...,
564+
level: IndexLabel | None = ...,
564565
as_index: _bool = ...,
565566
sort: _bool = ...,
566567
group_keys: _bool = ...,
@@ -573,7 +574,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
573574
self,
574575
by: TimedeltaIndex,
575576
axis: AxisIndex = ...,
576-
level: Level | None = ...,
577+
level: IndexLabel | None = ...,
577578
as_index: _bool = ...,
578579
sort: _bool = ...,
579580
group_keys: _bool = ...,
@@ -586,7 +587,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
586587
self,
587588
by: PeriodIndex,
588589
axis: AxisIndex = ...,
589-
level: Level | None = ...,
590+
level: IndexLabel | None = ...,
590591
as_index: _bool = ...,
591592
sort: _bool = ...,
592593
group_keys: _bool = ...,
@@ -599,7 +600,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
599600
self,
600601
by: IntervalIndex[IntervalT],
601602
axis: AxisIndex = ...,
602-
level: Level | None = ...,
603+
level: IndexLabel | None = ...,
603604
as_index: _bool = ...,
604605
sort: _bool = ...,
605606
group_keys: _bool = ...,
@@ -612,7 +613,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
612613
self,
613614
by: MultiIndex | GroupByObjectNonScalar = ...,
614615
axis: AxisIndex = ...,
615-
level: Level | None = ...,
616+
level: IndexLabel | None = ...,
616617
as_index: _bool = ...,
617618
sort: _bool = ...,
618619
group_keys: _bool = ...,
@@ -625,7 +626,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
625626
self,
626627
by: Series[SeriesByT],
627628
axis: AxisIndex = ...,
628-
level: Level | None = ...,
629+
level: IndexLabel | None = ...,
629630
as_index: _bool = ...,
630631
sort: _bool = ...,
631632
group_keys: _bool = ...,
@@ -638,7 +639,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
638639
self,
639640
by: CategoricalIndex | Index | Series,
640641
axis: AxisIndex = ...,
641-
level: Level | None = ...,
642+
level: IndexLabel | None = ...,
642643
as_index: _bool = ...,
643644
sort: _bool = ...,
644645
group_keys: _bool = ...,

tests/test_frame.py

+15
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,21 @@ def test_types_groupby_iter() -> None:
10731073
)
10741074

10751075

1076+
def test_types_groupby_level() -> None:
1077+
# GH 836
1078+
data = {
1079+
"col1": [0, 0, 0],
1080+
"col2": [0, 1, 0],
1081+
"col3": [1, 2, 3],
1082+
"col4": [1, 2, 3],
1083+
}
1084+
df = pd.DataFrame(data=data).set_index(["col1", "col2", "col3"])
1085+
check(
1086+
assert_type(df.groupby(level=["col1", "col2"]).sum(), pd.DataFrame),
1087+
pd.DataFrame,
1088+
)
1089+
1090+
10761091
def test_types_merge() -> None:
10771092
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})
10781093
df2 = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [0, 1, 0]})

tests/test_series.py

+13
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,19 @@ def test_types_max() -> None:
439439
s.max(skipna=False)
440440

441441

442+
def test_types_groupby_level() -> None:
443+
# GH 836
444+
index = pd.MultiIndex.from_tuples(
445+
[(0, 0, 1), (0, 1, 2), (0, 0, 3)], names=["col1", "col2", "col3"]
446+
)
447+
s = pd.Series([1, 2, 3], index=index)
448+
check(
449+
assert_type(s.groupby(level=["col1", "col2"]).sum(), "pd.Series[int]"),
450+
pd.Series,
451+
np.integer,
452+
)
453+
454+
442455
def test_types_quantile() -> None:
443456
s = pd.Series([1, 2, 3, 10])
444457
s.quantile([0.25, 0.5])

0 commit comments

Comments
 (0)