33
33
from typing_extensions import assert_type
34
34
import xarray as xr
35
35
36
- from pandas ._typing import Scalar
36
+ from pandas ._typing import (
37
+ AggFuncTypeBase ,
38
+ Scalar ,
39
+ )
37
40
38
41
from tests import (
39
42
TYPE_CHECKING_INVALID_USAGE ,
@@ -643,7 +646,9 @@ def test_types_groupby_methods() -> None:
643
646
644
647
645
648
def test_types_groupby_agg () -> None :
646
- df = pd .DataFrame (data = {"col1" : [1 , 1 , 2 ], "col2" : [3 , 4 , 5 ], "col3" : [0 , 1 , 0 ]})
649
+ df = pd .DataFrame (
650
+ data = {"col1" : [1 , 1 , 2 ], "col2" : [3 , 4 , 5 ], "col3" : [0 , 1 , 0 ], 0 : [- 1 , - 1 , - 1 ]}
651
+ )
647
652
check (assert_type (df .groupby ("col1" )["col3" ].agg (min ), pd .Series ), pd .Series )
648
653
check (
649
654
assert_type (df .groupby ("col1" )["col3" ].agg ([min , max ]), pd .DataFrame ),
@@ -655,21 +660,19 @@ def test_types_groupby_agg() -> None:
655
660
assert_type (df .groupby ("col1" ).agg (["min" , "max" ]), pd .DataFrame ), pd .DataFrame
656
661
)
657
662
check (assert_type (df .groupby ("col1" ).agg ([min , max ]), pd .DataFrame ), pd .DataFrame )
663
+ agg_dict1 : dict [Hashable , str ] = {"col2" : "min" , "col3" : "max" , 0 : "avg" }
664
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict1 ), pd .DataFrame ), pd .DataFrame )
665
+ agg_dict2 : dict [Hashable , AggFuncTypeBase ] = {"col2" : min , "col3" : max , 0 : min }
666
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict2 ), pd .DataFrame ), pd .DataFrame )
667
+ agg_dict3 : dict [Hashable , str | AggFuncTypeBase ] = {
668
+ "col2" : min ,
669
+ "col3" : "max" ,
670
+ 0 : lambda x : x .min (),
671
+ }
672
+ check (assert_type (df .groupby ("col1" ).agg (agg_dict3 ), pd .DataFrame ), pd .DataFrame )
673
+ named_agg = pd .NamedAgg (column = "col2" , aggfunc = "max" )
658
674
check (
659
- assert_type (
660
- df .groupby ("col1" ).agg ({"col2" : "min" , "col3" : "max" }), pd .DataFrame
661
- ),
662
- pd .DataFrame ,
663
- )
664
- check (
665
- assert_type (df .groupby ("col1" ).agg ({"col2" : min , "col3" : max }), pd .DataFrame ),
666
- pd .DataFrame ,
667
- )
668
- check (
669
- assert_type (
670
- df .groupby ("col1" ).agg (new_col = pd .NamedAgg (column = "col2" , aggfunc = "max" )),
671
- pd .DataFrame ,
672
- ),
675
+ assert_type (df .groupby ("col1" ).agg (new_col = named_agg ), pd .DataFrame ),
673
676
pd .DataFrame ,
674
677
)
675
678
# GH#187
@@ -679,6 +682,9 @@ def test_types_groupby_agg() -> None:
679
682
cols_opt : list [str | None ] = ["col1" , "col2" ]
680
683
check (assert_type (df .groupby (by = cols_opt ).sum (), pd .DataFrame ), pd .DataFrame )
681
684
685
+ cols_mixed : list [str | int ] = ["col1" , 0 ]
686
+ check (assert_type (df .groupby (by = cols_mixed ).sum (), pd .DataFrame ), pd .DataFrame )
687
+
682
688
683
689
# This was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html
684
690
def test_types_group_by_with_dropna_keyword () -> None :
0 commit comments