@@ -27,12 +27,16 @@ class SpecificationError(GroupByError):
27
27
pass
28
28
29
29
30
- def _groupby_function (name , alias , npfunc , numeric_only = True ):
30
+ def _groupby_function (name , alias , npfunc , numeric_only = True ,
31
+ _convert = False ):
31
32
def f (self ):
32
33
try :
33
34
return self ._cython_agg_general (alias , numeric_only = numeric_only )
34
35
except Exception :
35
- return self .aggregate (lambda x : npfunc (x , axis = self .axis ))
36
+ result = self .aggregate (lambda x : npfunc (x , axis = self .axis ))
37
+ if _convert :
38
+ result = result .convert_objects ()
39
+ return result
36
40
37
41
f .__doc__ = "Compute %s of group values" % name
38
42
f .__name__ = name
@@ -41,19 +45,31 @@ def f(self):
41
45
42
46
43
47
def _first_compat(x, axis=0):
    """
    Return the first non-null entry of ``x``.

    Parameters
    ----------
    x : array-like or DataFrame
        Values to scan. A DataFrame is reduced with ``DataFrame.apply``
        along ``axis``; anything else is converted with ``np.asarray``
        and reduced directly.
    axis : int, default 0
        Only forwarded to ``DataFrame.apply`` when ``x`` is a DataFrame.

    Returns
    -------
    The first element left after dropping nulls, or ``np.nan`` when
    nothing is left. For a DataFrame, whatever ``apply`` builds from
    those per-slice results.
    """
    def _pick_first(values):
        # Drop nulls first so that position 0 is the first valid entry.
        arr = np.asarray(values)
        valid = arr[com.notnull(arr)]
        return valid[0] if len(valid) else np.nan

    if not isinstance(x, DataFrame):
        return _pick_first(x)
    return x.apply(_pick_first, axis=axis)
49
59
50
60
51
61
def _last_compat(x, axis=0):
    """
    Return the last non-null entry of ``x``.

    Parameters
    ----------
    x : array-like or DataFrame
        Values to scan. A DataFrame is reduced with ``DataFrame.apply``
        along ``axis``; anything else is converted with ``np.asarray``
        and reduced directly.
    axis : int, default 0
        Only forwarded to ``DataFrame.apply`` when ``x`` is a DataFrame.

    Returns
    -------
    The final element left after dropping nulls, or ``np.nan`` when
    nothing is left. For a DataFrame, whatever ``apply`` builds from
    those per-slice results.
    """
    def _pick_last(values):
        # Drop nulls first so that position -1 is the last valid entry.
        arr = np.asarray(values)
        valid = arr[com.notnull(arr)]
        return valid[-1] if len(valid) else np.nan

    if not isinstance(x, DataFrame):
        return _pick_last(x)
    return x.apply(_pick_last, axis=axis)
57
73
58
74
59
75
class GroupBy (object ):
@@ -357,8 +373,9 @@ def size(self):
357
373
min = _groupby_function ('min' , 'min' , np .min )
358
374
max = _groupby_function ('max' , 'max' , np .max )
359
375
first = _groupby_function ('first' , 'first' , _first_compat ,
360
- numeric_only = False )
361
- last = _groupby_function ('last' , 'last' , _last_compat , numeric_only = False )
376
+ numeric_only = False , _convert = True )
377
+ last = _groupby_function ('last' , 'last' , _last_compat , numeric_only = False ,
378
+ _convert = True )
362
379
363
380
def ohlc (self ):
364
381
"""
@@ -380,7 +397,7 @@ def picker(arr):
380
397
def _cython_agg_general (self , how , numeric_only = True ):
381
398
output = {}
382
399
for name , obj in self ._iterate_slices ():
383
- is_numeric = issubclass (obj .dtype . type , ( np . number , np . bool_ ) )
400
+ is_numeric = _is_numeric_dtype (obj .dtype )
384
401
if numeric_only and not is_numeric :
385
402
continue
386
403
@@ -699,12 +716,6 @@ def get_group_levels(self):
699
716
_filter_empty_groups = True
700
717
701
718
def aggregate (self , values , how , axis = 0 ):
702
- values = com .ensure_float (values )
703
- is_numeric = True
704
-
705
- if not issubclass (values .dtype .type , (np .number , np .bool_ )):
706
- values = values .astype (object )
707
- is_numeric = False
708
719
709
720
arity = self ._cython_arity .get (how , 1 )
710
721
@@ -721,6 +732,16 @@ def aggregate(self, values, how, axis=0):
721
732
raise NotImplementedError
722
733
out_shape = (self .ngroups ,) + values .shape [1 :]
723
734
735
+ if _is_numeric_dtype (values .dtype ):
736
+ values = com .ensure_float (values )
737
+ is_numeric = True
738
+ else :
739
+ if issubclass (values .dtype .type , np .datetime64 ):
740
+ raise Exception ('Cython not able to handle this case' )
741
+
742
+ values = values .astype (object )
743
+ is_numeric = False
744
+
724
745
# will be filled in Cython function
725
746
result = np .empty (out_shape , dtype = values .dtype )
726
747
counts = np .zeros (self .ngroups , dtype = np .int64 )
@@ -753,10 +774,11 @@ def aggregate(self, values, how, axis=0):
753
774
return result , names
754
775
755
776
def _aggregate (self , result , counts , values , how , is_numeric ):
756
- fdict = self ._cython_functions
757
777
if not is_numeric :
758
- fdict = self ._cython_object_functions
759
- agg_func = fdict [how ]
778
+ agg_func = self ._cython_object_functions [how ]
779
+ else :
780
+ agg_func = self ._cython_functions [how ]
781
+
760
782
trans_func = self ._cython_transforms .get (how , lambda x : x )
761
783
762
784
comp_ids , _ , ngroups = self .group_info
@@ -1458,12 +1480,15 @@ def _cython_agg_blocks(self, how, numeric_only=True):
1458
1480
1459
1481
for block in data .blocks :
1460
1482
values = block .values
1461
- is_numeric = issubclass (values .dtype .type , (np .number , np .bool_ ))
1483
+
1484
+ is_numeric = _is_numeric_dtype (values .dtype )
1485
+
1462
1486
if numeric_only and not is_numeric :
1463
1487
continue
1464
1488
1465
1489
if is_numeric :
1466
1490
values = com .ensure_float (values )
1491
+
1467
1492
result , _ = self .grouper .aggregate (values , how , axis = agg_axis )
1468
1493
newb = make_block (result , block .items , block .ref_items )
1469
1494
new_blocks .append (newb )
@@ -2231,6 +2256,12 @@ def _reorder_by_uniques(uniques, labels):
2231
2256
}
2232
2257
2233
2258
2259
+ def _is_numeric_dtype (dt ):
2260
+ typ = dt .type
2261
+ return (issubclass (typ , (np .number , np .bool_ ))
2262
+ and not issubclass (typ , (np .datetime64 , np .timedelta64 )))
2263
+
2264
+
2234
2265
def _intercept_function(func):
    """
    Return the entry registered for ``func`` in ``_func_table``, or
    ``func`` itself when the table has no entry for it.
    """
    replacement = _func_table.get(func, func)
    return replacement
2236
2267
0 commit comments