@@ -24,10 +24,10 @@ class DataError(GroupByError):
24
24
class SpecificationError (GroupByError ):
25
25
pass
26
26
27
- def _groupby_function (name , alias , npfunc ):
27
+ def _groupby_function (name , alias , npfunc , numeric_only = True ):
28
28
def f (self ):
29
29
try :
30
- return self ._cython_agg_general (alias )
30
+ return self ._cython_agg_general (alias , numeric_only = numeric_only )
31
31
except Exception :
32
32
return self .aggregate (lambda x : npfunc (x , axis = self .axis ))
33
33
@@ -350,8 +350,9 @@ def size(self):
350
350
prod = _groupby_function ('prod' , 'prod' , np .prod )
351
351
min = _groupby_function ('min' , 'min' , np .min )
352
352
max = _groupby_function ('max' , 'max' , np .max )
353
- first = _groupby_function ('first' , 'first' , _first_compat )
354
- last = _groupby_function ('last' , 'last' , _last_compat )
353
+ first = _groupby_function ('first' , 'first' , _first_compat ,
354
+ numeric_only = False )
355
+ last = _groupby_function ('last' , 'last' , _last_compat , numeric_only = False )
355
356
356
357
def ohlc (self ):
357
358
"""
@@ -370,10 +371,11 @@ def picker(arr):
370
371
return np .nan
371
372
return self .agg (picker )
372
373
373
- def _cython_agg_general (self , how ):
374
+ def _cython_agg_general (self , how , numeric_only = True ):
374
375
output = {}
375
376
for name , obj in self ._iterate_slices ():
376
- if not issubclass (obj .dtype .type , (np .number , np .bool_ )):
377
+ is_numeric = issubclass (obj .dtype .type , (np .number , np .bool_ ))
378
+ if numeric_only and not is_numeric :
377
379
continue
378
380
379
381
result , names = self .grouper .aggregate (obj .values , how )
@@ -668,6 +670,11 @@ def get_group_levels(self):
668
670
'last' : lib .group_last
669
671
}
670
672
673
+ _cython_object_functions = {
674
+ 'first' : lambda a , b , c , d : lib .group_nth_object (a , b , c , d , 1 ),
675
+ 'last' : lib .group_last_object
676
+ }
677
+
671
678
_cython_transforms = {
672
679
'std' : np .sqrt
673
680
}
@@ -681,7 +688,13 @@ def get_group_levels(self):
681
688
_filter_empty_groups = True
682
689
683
690
def aggregate (self , values , how , axis = 0 ):
684
- values = com ._ensure_float64 (values )
691
+ values = com .ensure_float (values )
692
+ is_numeric = True
693
+
694
+ if not issubclass (values .dtype .type , (np .number , np .bool_ )):
695
+ values = values .astype (object )
696
+ is_numeric = False
697
+
685
698
arity = self ._cython_arity .get (how , 1 )
686
699
687
700
vdim = values .ndim
@@ -698,15 +711,19 @@ def aggregate(self, values, how, axis=0):
698
711
out_shape = (self .ngroups ,) + values .shape [1 :]
699
712
700
713
# will be filled in Cython function
701
- result = np .empty (out_shape , dtype = np . float64 )
714
+ result = np .empty (out_shape , dtype = values . dtype )
702
715
counts = np .zeros (self .ngroups , dtype = np .int64 )
703
716
704
- result = self ._aggregate (result , counts , values , how )
717
+ result = self ._aggregate (result , counts , values , how , is_numeric )
705
718
706
719
if self ._filter_empty_groups :
707
720
if result .ndim == 2 :
708
- result = lib .row_bool_subset (result ,
709
- (counts > 0 ).view (np .uint8 ))
721
+ if is_numeric :
722
+ result = lib .row_bool_subset (result ,
723
+ (counts > 0 ).view (np .uint8 ))
724
+ else :
725
+ result = lib .row_bool_subset_object (result ,
726
+ (counts > 0 ).view (np .uint8 ))
710
727
else :
711
728
result = result [counts > 0 ]
712
729
@@ -724,8 +741,11 @@ def aggregate(self, values, how, axis=0):
724
741
725
742
return result , names
726
743
727
- def _aggregate (self , result , counts , values , how ):
728
- agg_func = self ._cython_functions [how ]
744
+ def _aggregate (self , result , counts , values , how , is_numeric ):
745
+ fdict = self ._cython_functions
746
+ if not is_numeric :
747
+ fdict = self ._cython_object_functions
748
+ agg_func = fdict [how ]
729
749
trans_func = self ._cython_transforms .get (how , lambda x : x )
730
750
731
751
comp_ids , _ , ngroups = self .group_info
@@ -913,14 +933,22 @@ def names(self):
913
933
'last' : lib .group_last_bin
914
934
}
915
935
936
+ _cython_object_functions = {
937
+ 'first' : lambda a , b , c , d : lib .group_nth_bin_object (a , b , c , d , 1 ),
938
+ 'last' : lib .group_last_bin_object
939
+ }
940
+
916
941
_name_functions = {
917
942
'ohlc' : lambda * args : ['open' , 'high' , 'low' , 'close' ]
918
943
}
919
944
920
945
_filter_empty_groups = True
921
946
922
- def _aggregate (self , result , counts , values , how ):
923
- agg_func = self ._cython_functions [how ]
947
+ def _aggregate (self , result , counts , values , how , is_numeric = True ):
948
+ fdict = self ._cython_functions
949
+ if not is_numeric :
950
+ fdict = self ._cython_object_functions
951
+ agg_func = fdict [how ]
924
952
trans_func = self ._cython_transforms .get (how , lambda x : x )
925
953
926
954
if values .ndim > 3 :
@@ -1385,8 +1413,8 @@ def _iterate_slices(self):
1385
1413
1386
1414
yield val , slicer (val )
1387
1415
1388
- def _cython_agg_general (self , how ):
1389
- new_blocks = self ._cython_agg_blocks (how )
1416
+ def _cython_agg_general (self , how , numeric_only = True ):
1417
+ new_blocks = self ._cython_agg_blocks (how , numeric_only = numeric_only )
1390
1418
return self ._wrap_agged_blocks (new_blocks )
1391
1419
1392
1420
def _wrap_agged_blocks (self , blocks ):
@@ -1408,18 +1436,20 @@ def _wrap_agged_blocks(self, blocks):
1408
1436
1409
1437
_block_agg_axis = 0
1410
1438
1411
- def _cython_agg_blocks (self , how ):
1439
+ def _cython_agg_blocks (self , how , numeric_only = True ):
1412
1440
data , agg_axis = self ._get_data_to_aggregate ()
1413
1441
1414
1442
new_blocks = []
1415
1443
1416
1444
for block in data .blocks :
1417
1445
values = block .values
1418
- if not issubclass (values .dtype .type , (np .number , np .bool_ )):
1446
+ is_numeric = issubclass (values .dtype .type , (np .number , np .bool_ ))
1447
+ if numeric_only and not is_numeric :
1419
1448
continue
1420
1449
1421
- values = com ._ensure_float64 (values )
1422
- result , names = self .grouper .aggregate (values , how , axis = agg_axis )
1450
+ if is_numeric :
1451
+ values = com .ensure_float (values )
1452
+ result , _ = self .grouper .aggregate (values , how , axis = agg_axis )
1423
1453
newb = make_block (result , block .items , block .ref_items )
1424
1454
new_blocks .append (newb )
1425
1455
@@ -2210,5 +2240,3 @@ def complete_dataframe(obj, prev_completions):
2210
2240
install_ipython_completers ()
2211
2241
except Exception :
2212
2242
pass
2213
-
2214
-
0 commit comments