@@ -341,7 +341,7 @@ def test_aggregate_numbers(self):
341
341
),
342
342
)
343
343
344
- def test_aggregate_strings (self ):
344
+ def test_aggregate_text_values (self ):
345
345
result = groupby (
346
346
pd .DataFrame ({"A" : [1 , 1 , 1 ], "B" : ["a" , "b" , "a" ]}),
347
347
[Group ("A" , None )],
@@ -367,6 +367,60 @@ def test_aggregate_strings(self):
367
367
),
368
368
)
369
369
370
def test_aggregate_text_category_values(self):
    """Aggregate a categorical text column inside a single group.

    Column "B" holds the values "a", "b", "a" as a pandas categorical and
    every row shares the same "A" key, so exactly one output row is
    expected. SIZE and NUNIQUE are expected as plain integers, while MIN,
    MAX and FIRST are expected to come back as categorical columns.
    """
    text_as_category = pd.Series(["a", "b", "a"], dtype="category")
    table = pd.DataFrame({"A": [1, 1, 1], "B": text_as_category})

    # (operation, output column name) pairs, all applied to column "B".
    requested = (
        (Operation.SIZE, "size"),
        (Operation.NUNIQUE, "nunique"),
        (Operation.MIN, "min"),
        (Operation.MAX, "max"),
        (Operation.FIRST, "first"),
    )
    aggregations = [Aggregation(op, "B", outname) for op, outname in requested]

    result = groupby(table, [Group("A", None)], aggregations)

    expected = pd.DataFrame(
        {
            "A": [1],
            "size": [3],
            "nunique": [2],
            "min": pd.Series(["a"], dtype="category"),
            "max": pd.Series(["b"], dtype="category"),
            "first": pd.Series(["a"], dtype="category"),
        }
    )
    assert_frame_equal(result, expected)
397
+
398
def test_aggregate_text_category_values_empty_still_has_object_dtype(self):
    """Group a categorical column whose only value is missing.

    The input has a single row in column "A" holding None, built with
    dtype=str and then cast to category. The expected result is a frame
    with zero rows: "A", "min", "max" and "first" are empty categoricals
    derived from empty str-typed series, while "size" and "nunique" are
    empty int series.
    """

    def empty_text_category():
        # A fresh zero-length categorical built from a str-typed series.
        return pd.Series([], dtype=str).astype("category")

    def empty_int():
        # A fresh zero-length integer series for the count-like outputs.
        return pd.Series([], dtype=int)

    table = pd.DataFrame({"A": [None]}, dtype=str).astype("category")

    # (operation, output column name) pairs, all applied to column "A".
    requested = (
        (Operation.SIZE, "size"),
        (Operation.NUNIQUE, "nunique"),
        (Operation.MIN, "min"),
        (Operation.MAX, "max"),
        (Operation.FIRST, "first"),
    )
    aggregations = [Aggregation(op, "A", outname) for op, outname in requested]

    result = groupby(table, [Group("A", None)], aggregations)

    expected = pd.DataFrame(
        {
            "A": empty_text_category(),
            "size": empty_int(),
            "nunique": empty_int(),
            "min": empty_text_category(),
            "max": empty_text_category(),
            "first": empty_text_category(),
        }
    )
    assert_frame_equal(result, expected)
423
+
370
424
def test_aggregate_datetime_no_granularity (self ):
371
425
result = groupby (
372
426
pd .DataFrame ({"A" : [dt (2018 , 1 , 4 ), dt (2018 , 1 , 5 ), dt (2018 , 1 , 4 )]}),
0 commit comments