@@ -495,6 +495,80 @@ def test_mangled(self):
495
495
tm .assert_frame_equal (result , expected )
496
496
497
497
498
+ @pytest .mark .parametrize (
499
+ "agg_col1, agg_col2, agg_col3, agg_result1, agg_result2, agg_result3" ,
500
+ [
501
+ (
502
+ (("y" , "A" ), "max" ),
503
+ (("y" , "A" ), np .min ),
504
+ (("y" , "B" ), "mean" ),
505
+ [1 , 3 ],
506
+ [0 , 2 ],
507
+ [5.5 , 7.5 ],
508
+ ),
509
+ (
510
+ (("y" , "A" ), lambda x : max (x )),
511
+ (("y" , "A" ), lambda x : 1 ),
512
+ (("y" , "B" ), "mean" ),
513
+ [1 , 3 ],
514
+ [1 , 1 ],
515
+ [5.5 , 7.5 ],
516
+ ),
517
+ (
518
+ pd .NamedAgg (("y" , "A" ), "max" ),
519
+ pd .NamedAgg (("y" , "B" ), np .mean ),
520
+ pd .NamedAgg (("y" , "A" ), lambda x : 1 ),
521
+ [1 , 3 ],
522
+ [5.5 , 7.5 ],
523
+ [1 , 1 ],
524
+ ),
525
+ ],
526
+ )
527
+ def test_agg_relabel_multiindex_column (
528
+ agg_col1 , agg_col2 , agg_col3 , agg_result1 , agg_result2 , agg_result3
529
+ ):
530
+ # GH 29422, add tests for multiindex column cases
531
+ df = DataFrame (
532
+ {"group" : ["a" , "a" , "b" , "b" ], "A" : [0 , 1 , 2 , 3 ], "B" : [5 , 6 , 7 , 8 ]}
533
+ )
534
+ df .columns = pd .MultiIndex .from_tuples ([("x" , "group" ), ("y" , "A" ), ("y" , "B" )])
535
+ idx = pd .Index (["a" , "b" ], name = ("x" , "group" ))
536
+
537
+ result = df .groupby (("x" , "group" )).agg (a_max = (("y" , "A" ), "max" ))
538
+ expected = DataFrame ({"a_max" : [1 , 3 ]}, index = idx )
539
+ tm .assert_frame_equal (result , expected )
540
+
541
+ result = df .groupby (("x" , "group" )).agg (
542
+ col_1 = agg_col1 , col_2 = agg_col2 , col_3 = agg_col3
543
+ )
544
+ expected = DataFrame (
545
+ {"col_1" : agg_result1 , "col_2" : agg_result2 , "col_3" : agg_result3 }, index = idx
546
+ )
547
+ tm .assert_frame_equal (result , expected )
548
+
549
+
550
+ def test_agg_relabel_multiindex_raises_not_exist ():
551
+ # GH 29422, add test for raises senario when aggregate column does not exist
552
+ df = DataFrame (
553
+ {"group" : ["a" , "a" , "b" , "b" ], "A" : [0 , 1 , 2 , 3 ], "B" : [5 , 6 , 7 , 8 ]}
554
+ )
555
+ df .columns = pd .MultiIndex .from_tuples ([("x" , "group" ), ("y" , "A" ), ("y" , "B" )])
556
+
557
+ with pytest .raises (KeyError , match = "does not exist" ):
558
+ df .groupby (("x" , "group" )).agg (a = (("Y" , "a" ), "max" ))
559
+
560
+
561
+ def test_agg_relabel_multiindex_raises_duplicate ():
562
+ # GH29422, add test for raises senario when getting duplicates
563
+ df = DataFrame (
564
+ {"group" : ["a" , "a" , "b" , "b" ], "A" : [0 , 1 , 2 , 3 ], "B" : [5 , 6 , 7 , 8 ]}
565
+ )
566
+ df .columns = pd .MultiIndex .from_tuples ([("x" , "group" ), ("y" , "A" ), ("y" , "B" )])
567
+
568
+ with pytest .raises (SpecificationError , match = "Function names" ):
569
+ df .groupby (("x" , "group" )).agg (a = (("y" , "A" ), "min" ), b = (("y" , "A" ), "min" ))
570
+
571
+
498
572
def myfunc (s ):
499
573
return np .percentile (s , q = 0.90 )
500
574
0 commit comments