@@ -453,28 +453,28 @@ def test_dataframe_categorical_with_nan(observed):
453
453
tm .assert_frame_equal (result , expected )
454
454
455
455
456
+ @pytest .mark .parametrize ("ordered" , [True , False ])
456
457
@pytest .mark .parametrize ("observed" , [True , False ])
457
458
@pytest .mark .parametrize ("sort" , [True , False ])
458
- def test_dataframe_categorical_ordered_observed (observed , sort ):
459
- # GH 25871
460
- cat = pd .Categorical ([3 , 1 , 2 , 1 , 3 , 2 ], categories = [1 , 2 , 3 , 4 ], ordered = True )
461
- val = pd .Series ([1.5 , 0.5 , 1.0 , 0.5 , 1.5 , 1.0 ])
459
+ def test_dataframe_categorical_ordered_observed_sort (ordered , observed , sort ):
460
+ # GH 25871: Fix groupby sorting on ordered Categoricals
461
+ # Build a dataframe with a Categorical having one unobserved category ('AWOL'), and a Series with identical values
462
+ cat = pd .Categorical (['d' , 'a' , 'b' , 'a' , 'd' , 'b' ], categories = ['a' , 'b' , 'AWOL' , 'd' ], ordered = ordered )
463
+ val = pd .Series (['d' , 'a' , 'b' , 'a' , 'd' , 'b' ])
462
464
df = pd .DataFrame ({'cat' : cat , 'val' : val })
463
- result = df .groupby ('cat' , observed = observed , sort = sort )['val' ].agg ('sum' )
464
-
465
- # For ordered Categoricals, sort must have no influence on the result (they always sort)
466
- if observed :
467
- expected = pd .Series (data = [1.0 , 2.0 , 3.0 ],
468
- index = pd .CategoricalIndex ([1 , 2 , 3 ], categories = [1 , 2 , 3 , 4 ], ordered = True , name = 'cat' , dtype = 'category' ),
469
- dtype = 'float64' , name = 'val' )
470
- else :
471
- expected = pd .Series (data = [1.0 , 2.0 , 3.0 , 0.0 ],
472
- index = pd .CategoricalIndex ([1 , 2 , 3 , 4 ], categories = [1 , 2 , 3 , 4 ], ordered = True , name = 'cat' , dtype = 'category' ),
473
- dtype = 'float64' , name = 'val' )
474
-
475
- tm .assert_series_equal (result , expected )
476
465
466
+ # aggregate on the Categorical
467
+ result = df .groupby ('cat' , observed = observed , sort = sort )['val' ].agg ('first' )
468
+
469
+ # If ordering is correct, we expect index labels equal to aggregation results,
470
+ # except for 'observed=False', when index contains 'AWOL' and aggregation None
471
+ label = pd .Series (result .index .array , dtype = 'object' )
472
+ aggr = pd .Series (result .array )
473
+ if not observed :
474
+ aggr [aggr .isna ()] = 'AWOL'
475
+ tm .assert_equal (label , aggr )
477
476
477
+
478
478
def test_datetime ():
479
479
# GH9049: ensure backward compatibility
480
480
levels = pd .date_range ('2014-01-01' , periods = 4 )
0 commit comments