40
40
from pandas .core .dtypes .missing import _maybe_fill , isna
41
41
42
42
import pandas .core .algorithms as algorithms
43
+ from pandas .core .arrays .categorical import Categorical
43
44
from pandas .core .base import SelectionMixin
44
45
import pandas .core .common as com
45
46
from pandas .core .frame import DataFrame
@@ -356,6 +357,29 @@ def get_group_levels(self) -> List[Index]:
356
357
357
358
# Maps a `how` name to a list of labels; only "ohlc" has an entry.
# NOTE(review): presumably the output column names for that kernel --
# confirm against the code that reads this mapping (outside this view).
_name_functions = {"ohlc": ["open", "high", "low", "close"]}

# `how` values that are refused for categorical input: the cython dispatch
# raises NotImplementedError when `how` is found here.
# Stored as a frozenset because the visible use is a pure membership test
# (`how in self._cat_method_blacklist`); element order carries no meaning.
_cat_method_blacklist = frozenset(
    {
        "add",
        "median",
        "prod",
        "sem",
        "cumsum",
        "sum",
        "cummin",
        "mean",
        "max",
        "skew",
        "cumprod",
        "cummax",
        "rank",
        "pct_change",
        "min",
        "var",
        "mad",
        "describe",
        "std",
        "quantile",
    }
)
382
+
359
383
def _is_builtin_func (self , arg ):
360
384
"""
361
385
if we define a builtin function for this argument, return it,
@@ -460,7 +484,7 @@ def _cython_operation(
460
484
461
485
# sparse arrays are only 1d, so we
462
486
# are not set up for dim transforming
463
- if is_categorical_dtype ( values . dtype ) or is_sparse (values .dtype ):
487
+ if is_sparse (values .dtype ):
464
488
raise NotImplementedError (f"{ values .dtype } dtype not supported" )
465
489
elif is_datetime64_any_dtype (values .dtype ):
466
490
if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
@@ -481,6 +505,7 @@ def _cython_operation(
481
505
482
506
is_datetimelike = needs_i8_conversion (values .dtype )
483
507
is_numeric = is_numeric_dtype (values .dtype )
508
+ is_categorical = is_categorical_dtype (values )
484
509
485
510
if is_datetimelike :
486
511
values = values .view ("int64" )
@@ -496,6 +521,17 @@ def _cython_operation(
496
521
values = ensure_int_or_float (values )
497
522
elif is_numeric and not is_complex_dtype (values ):
498
523
values = ensure_float64 (values )
524
+ elif is_categorical :
525
+ if how in self ._cat_method_blacklist :
526
+ raise NotImplementedError (
527
+ f"{ values .dtype } dtype not supported for `how` argument { how } "
528
+ )
529
+ values , categories , ordered = (
530
+ values .codes .astype (np .int64 ),
531
+ values .categories ,
532
+ values .ordered ,
533
+ )
534
+ is_numeric = True
499
535
else :
500
536
values = values .astype (object )
501
537
@@ -572,6 +608,11 @@ def _cython_operation(
572
608
result = type (orig_values )(result .astype (np .int64 ), dtype = orig_values .dtype )
573
609
elif is_datetimelike and kind == "aggregate" :
574
610
result = result .astype (orig_values .dtype )
611
+ elif is_categorical :
612
+ # re-create categories
613
+ result = Categorical .from_codes (
614
+ result , categories = categories , ordered = ordered ,
615
+ )
575
616
576
617
if is_extension_array_dtype (orig_values .dtype ):
577
618
result = maybe_cast_result (result = result , obj = orig_values , how = how )
0 commit comments