@@ -413,22 +413,44 @@ def test_mixed_groupings(normalize, expected_label, expected_values):
413
413
],
414
414
)
415
415
@pytest .mark .parametrize ("as_index" , [False , True ])
416
- def test_column_name_clashes (test , columns , expected_names , as_index ):
416
+ def test_column_label_duplicates (test , columns , expected_names , as_index ):
417
+ # Test for duplicate input column labels and generated duplicate labels
417
418
df = DataFrame ([[1 , 3 , 5 , 7 , 9 ], [2 , 4 , 6 , 8 , 10 ]], columns = columns )
418
-
419
+ expected_data = [(1 , 0 , 7 , 3 , 5 , 9 ), (2 , 1 , 8 , 4 , 6 , 10 )]
420
+ result = df .groupby (["a" , [0 , 1 ], "d" ], as_index = as_index ).value_counts ()
419
421
if as_index :
420
- result = df .groupby (["a" , [0 , 1 ], "d" ], as_index = as_index ).value_counts ()
421
422
expected = Series (
422
423
data = (1 , 1 ),
423
424
index = MultiIndex .from_tuples (
424
- [( 1 , 0 , 7 , 3 , 5 , 9 ), ( 2 , 1 , 8 , 4 , 6 , 10 )] ,
425
+ expected_data ,
425
426
names = expected_names ,
426
427
),
427
428
)
428
429
tm .assert_series_equal (result , expected )
429
430
else :
430
- with pytest .raises (ValueError , match = "cannot insert" ):
431
- df .groupby (["a" , [0 , 1 ], "d" ], as_index = as_index ).value_counts ()
431
+ expected_data = [list (row ) + [1 ] for row in expected_data ]
432
+ expected_columns = list (expected_names )
433
+ expected_columns [1 ] = "level_1"
434
+ expected_columns .append ("count" )
435
+ expected = DataFrame (expected_data , columns = expected_columns )
436
+ tm .assert_frame_equal (result , expected )
437
+
438
+
439
@pytest.mark.parametrize(
    "normalize, expected_label",
    [
        (False, "count"),
        (True, "proportion"),
    ],
)
def test_result_label_duplicates(normalize, expected_label):
    """Check that ``value_counts`` rejects an input column named like its result.

    A one-row frame is built whose third column carries ``expected_label``
    ("count" when ``normalize`` is False, "proportion" when it is True) and
    is grouped by "a" with ``as_index=False``.  Calling ``value_counts`` on
    that groupby must raise ``ValueError`` with a message naming the
    offending label.
    """
    # Test for result column label duplicating an input column label
    gb = DataFrame([[1, 2, 3]], columns=["a", "b", expected_label]).groupby(
        "a", as_index=False
    )
    # The label must sit flush against its quotes: ``match`` is used as a
    # regex search on the exception text, so any stray whitespace inside the
    # quotes would have to appear literally in the raised message.
    msg = f"Column label '{expected_label}' is duplicate of result column"
    with pytest.raises(ValueError, match=msg):
        gb.value_counts(normalize=normalize)
432
454
433
455
434
456
def test_ambiguous_grouping ():
0 commit comments