@@ -691,14 +691,25 @@ def test_agg_relabel_multiindex_duplicates():
691
691
tm .assert_frame_equal (result , expected )
692
692
693
693
694
- def test_multiindex_custom_func ():
694
+ @pytest .mark .parametrize (
695
+ "func, expected_values" ,
696
+ [
697
+ (lambda s : s .mean (), [[3 , 2 ], [5.5 , 8.0 ], [1.5 , 3.0 ], [6.0 , 5.5 ]]),
698
+ (np .mean , [[3.0 , 2.0 ], [5.5 , 8.0 ], [1.5 , 3.0 ], [6.0 , 5.5 ]]),
699
+ (np .nanmean , [[3.0 , 2.0 ], [5.5 , 8.0 ], [1.5 , 3.0 ], [6.0 , 5.5 ]]),
700
+ ],
701
+ )
702
+ def test_multiindex_custom_func (func , expected_values ):
695
703
# GH 31777
696
- df = pd .DataFrame (
697
- np .random .rand (10 , 4 ), columns = pd .MultiIndex .from_product ([[1 , 2 ], [3 , 4 ]])
704
+ data = [[1 , 4 , 2 , 8 ], [5 , 7 , 1 , 4 ], [2 , 8 , 1 , 4 ], [2 , 8 , 5 , 7 ]]
705
+ df = pd .DataFrame (data , columns = pd .MultiIndex .from_product ([[1 , 2 ], [3 , 4 ]]))
706
+ grp = df .groupby (np .r_ [np .zeros (2 ), np .ones (2 )])
707
+ result = grp .agg (func )
708
+ expected_keys = [(1 , 3 ), (1 , 4 ), (2 , 3 ), (2 , 4 )]
709
+ expected = pd .DataFrame (
710
+ {key : value for key , value in zip (expected_keys , expected_values )},
711
+ index = Index ([0.0 , 1.0 ], dtype = float ),
698
712
)
699
- grp = df .groupby (np .r_ [np .ones (5 ), np .zeros (5 )])
700
- result = grp .agg (lambda s : s .mean ())
701
- expected = grp .agg ("mean" )
702
713
tm .assert_frame_equal (result , expected )
703
714
704
715
0 commit comments