@@ -643,7 +643,9 @@ def test_types_groupby_methods() -> None:
643
643
644
644
645
645
def test_types_groupby_agg () -> None :
646
- df = pd .DataFrame (data = {"col1" : [1 , 1 , 2 ], "col2" : [3 , 4 , 5 ], "col3" : [0 , 1 , 0 ]})
646
+ df = pd .DataFrame (
647
+ data = {"col1" : [1 , 1 , 2 ], "col2" : [3 , 4 , 5 ], "col3" : [0 , 1 , 0 ], 0 : [- 1 , - 1 , - 1 ]}
648
+ )
647
649
check (assert_type (df .groupby ("col1" )["col3" ].agg (min ), pd .Series ), pd .Series )
648
650
check (
649
651
assert_type (df .groupby ("col1" )["col3" ].agg ([min , max ]), pd .DataFrame ),
@@ -655,21 +657,28 @@ def test_types_groupby_agg() -> None:
655
657
assert_type (df .groupby ("col1" ).agg (["min" , "max" ]), pd .DataFrame ), pd .DataFrame
656
658
)
657
659
check (assert_type (df .groupby ("col1" ).agg ([min , max ]), pd .DataFrame ), pd .DataFrame )
660
+ agg_dict1 = {"col2" : "min" , "col3" : "max" , 0 : "sum" }
661
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict1 ), pd .DataFrame ), pd .DataFrame )
662
+ agg_dict2 = {"col2" : min , "col3" : max , 0 : min }
663
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict2 ), pd .DataFrame ), pd .DataFrame )
664
+
665
+ def wrapped_min (x : Any ) -> Any :
666
+ return x .min ()
667
+
668
+ # Here, mypy infers dict[object, object] for the literal, so the variable must be explicitly annotated
669
+ agg_dict3 : dict [str | int , str | Callable ] = {
670
+ "col2" : min ,
671
+ "col3" : "max" ,
672
+ 0 : wrapped_min ,
673
+ }
674
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict3 ), pd .DataFrame ), pd .DataFrame )
675
+ agg_dict4 = {"col2" : "sum" }
676
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict4 ), pd .DataFrame ), pd .DataFrame )
677
+ agg_dict5 = {0 : "sum" }
678
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict5 ), pd .DataFrame ), pd .DataFrame )
679
+ named_agg = pd .NamedAgg (column = "col2" , aggfunc = "max" )
658
680
check (
659
- assert_type (
660
- df .groupby ("col1" ).agg ({"col2" : "min" , "col3" : "max" }), pd .DataFrame
661
- ),
662
- pd .DataFrame ,
663
- )
664
- check (
665
- assert_type (df .groupby ("col1" ).agg ({"col2" : min , "col3" : max }), pd .DataFrame ),
666
- pd .DataFrame ,
667
- )
668
- check (
669
- assert_type (
670
- df .groupby ("col1" ).agg (new_col = pd .NamedAgg (column = "col2" , aggfunc = "max" )),
671
- pd .DataFrame ,
672
- ),
681
+ assert_type (df .groupby ("col1" ).agg (new_col = named_agg ), pd .DataFrame ),
673
682
pd .DataFrame ,
674
683
)
675
684
# GH#187
@@ -679,6 +688,9 @@ def test_types_groupby_agg() -> None:
679
688
cols_opt : list [str | None ] = ["col1" , "col2" ]
680
689
check (assert_type (df .groupby (by = cols_opt ).sum (), pd .DataFrame ), pd .DataFrame )
681
690
691
+ cols_mixed : list [str | int ] = ["col1" , 0 ]
692
+ check (assert_type (df .groupby (by = cols_mixed ).sum (), pd .DataFrame ), pd .DataFrame )
693
+
682
694
683
695
# This was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html
684
696
def test_types_group_by_with_dropna_keyword () -> None :
0 commit comments