@@ -565,6 +565,140 @@ def test_types_groupby_methods() -> None:
     check(assert_type(s.groupby(level=0).idxmin(), pd.Series), pd.Series)
 
 
+def test_groupby_result() -> None:
+    # GH 142
+    # since there are no columns in a Series, groupby name only works
+    # with a named index, we use a MultiIndex, so we can group by more
+    # than one level and test the non-scalar case
+    multi_index = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=["a", "b"])
+    s = pd.Series([0, 1, 2], index=multi_index, dtype=int)
+    iterator = s.groupby(["a", "b"]).__iter__()
+    assert_type(iterator, Iterator[Tuple[Tuple, "pd.Series[int]"]])
+    index, value = next(iterator)
+    assert_type((index, value), Tuple[Tuple, "pd.Series[int]"])
+
+    check(assert_type(index, Tuple), tuple, np.integer)
+    check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer)
+
+    iterator2 = s.groupby("a").__iter__()
+    assert_type(iterator2, Iterator[Tuple[Scalar, "pd.Series[int]"]])
+    index2, value2 = next(iterator2)
+    assert_type((index2, value2), Tuple[Scalar, "pd.Series[int]"])
+
+    check(assert_type(index2, Scalar), int)
+    check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer)
+
+    # GH 674
+    # grouping by pd.MultiIndex should always resolve to a tuple as well
+    iterator3 = s.groupby(multi_index).__iter__()
+    assert_type(iterator3, Iterator[Tuple[Tuple, "pd.Series[int]"]])
+    index3, value3 = next(iterator3)
+    assert_type((index3, value3), Tuple[Tuple, "pd.Series[int]"])
+
+    check(assert_type(index3, Tuple), tuple, int)
+    check(assert_type(value3, "pd.Series[int]"), pd.Series, np.integer)
+
+    # Want to make sure these cases are differentiated
+    for (k1, k2), g in s.groupby(["a", "b"]):
+        pass
+
+    for kk, g in s.groupby("a"):
+        pass
+
+    for (k1, k2), g in s.groupby(multi_index):
+        pass
+
+
+def test_groupby_result_for_scalar_indexes() -> None:
+    # GH 674
+    s = pd.Series([0, 1, 2], dtype=int)
+    dates = pd.Series(
+        [
+            pd.Timestamp("2020-01-01"),
+            pd.Timestamp("2020-01-15"),
+            pd.Timestamp("2020-02-01"),
+        ],
+        dtype="datetime64[ns]",
+    )
+
+    period_index = pd.PeriodIndex(dates, freq="M")
+    iterator = s.groupby(period_index).__iter__()
+    assert_type(iterator, Iterator[Tuple[pd.Period, "pd.Series[int]"]])
+    index, value = next(iterator)
+    assert_type((index, value), Tuple[pd.Period, "pd.Series[int]"])
+
+    check(assert_type(index, pd.Period), pd.Period)
+    check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer)
+
+    dt_index = pd.DatetimeIndex(dates)
+    iterator2 = s.groupby(dt_index).__iter__()
+    assert_type(iterator2, Iterator[Tuple[pd.Timestamp, "pd.Series[int]"]])
+    index2, value2 = next(iterator2)
+    assert_type((index2, value2), Tuple[pd.Timestamp, "pd.Series[int]"])
+
+    check(assert_type(index2, pd.Timestamp), pd.Timestamp)
+    check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer)
+
+    tdelta_index = pd.TimedeltaIndex(dates - pd.Timestamp("2020-01-01"))
+    iterator3 = s.groupby(tdelta_index).__iter__()
+    assert_type(iterator3, Iterator[Tuple[pd.Timedelta, "pd.Series[int]"]])
+    index3, value3 = next(iterator3)
+    assert_type((index3, value3), Tuple[pd.Timedelta, "pd.Series[int]"])
+
+    check(assert_type(index3, pd.Timedelta), pd.Timedelta)
+    check(assert_type(value3, "pd.Series[int]"), pd.Series, np.integer)
+
+    intervals: list[pd.Interval[pd.Timestamp]] = [
+        pd.Interval(date, date + pd.DateOffset(days=1), closed="left") for date in dates
+    ]
+    interval_index = pd.IntervalIndex(intervals)
+    assert_type(interval_index, "pd.IntervalIndex[pd.Interval[pd.Timestamp]]")
+    iterator4 = s.groupby(interval_index).__iter__()
+    assert_type(
+        iterator4, Iterator[Tuple["pd.Interval[pd.Timestamp]", "pd.Series[int]"]]
+    )
+    index4, value4 = next(iterator4)
+    assert_type((index4, value4), Tuple["pd.Interval[pd.Timestamp]", "pd.Series[int]"])
+
+    check(assert_type(index4, "pd.Interval[pd.Timestamp]"), pd.Interval)
+    check(assert_type(value4, "pd.Series[int]"), pd.Series, np.integer)
+
+    for p, g in s.groupby(period_index):
+        pass
+
+    for dt, g in s.groupby(dt_index):
+        pass
+
+    for tdelta, g in s.groupby(tdelta_index):
+        pass
+
+    for interval, g in s.groupby(interval_index):
+        pass
+
+
+def test_groupby_result_for_ambiguous_indexes() -> None:
+    # GH 674
+    s = pd.Series([0, 1, 2], index=["a", "b", "a"], dtype=int)
+    # this will use pd.Index which is ambiguous
+    iterator = s.groupby(s.index).__iter__()
+    assert_type(iterator, Iterator[Tuple[Any, "pd.Series[int]"]])
+    index, value = next(iterator)
+    assert_type((index, value), Tuple[Any, "pd.Series[int]"])
+
+    check(assert_type(index, Any), str)
+    check(assert_type(value, "pd.Series[int]"), pd.Series, np.integer)
+
+    # categorical indexes are also ambiguous
+    categorical_index = pd.CategoricalIndex(s.index)
+    iterator2 = s.groupby(categorical_index).__iter__()
+    assert_type(iterator2, Iterator[Tuple[Any, "pd.Series[int]"]])
+    index2, value2 = next(iterator2)
+    assert_type((index2, value2), Tuple[Any, "pd.Series[int]"])
+
+    check(assert_type(index2, Any), str)
+    check(assert_type(value2, "pd.Series[int]"), pd.Series, np.integer)
+
+
 def test_types_groupby_agg() -> None:
     s = pd.Series([4, 2, 1, 8], index=["a", "b", "a", "b"])
     check(assert_type(s.groupby(level=0).agg("sum"), pd.Series), pd.Series)