Skip to content

Commit 8fc8d6f

Browse files
committed
Merge remote-tracking branch 'chang/groupby-last'
* chang/groupby-last: cython methods for group bins #1809 BUG: allow non-numeric columns in groupby first/last #1809
2 parents 163cc8a + d0c9957 commit 8fc8d6f

File tree

4 files changed

+254
-27
lines changed

4 files changed

+254
-27
lines changed

pandas/core/groupby.py

+51-23
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ class DataError(GroupByError):
2424
class SpecificationError(GroupByError):
2525
pass
2626

27-
def _groupby_function(name, alias, npfunc):
27+
def _groupby_function(name, alias, npfunc, numeric_only=True):
2828
def f(self):
2929
try:
30-
return self._cython_agg_general(alias)
30+
return self._cython_agg_general(alias, numeric_only=numeric_only)
3131
except Exception:
3232
return self.aggregate(lambda x: npfunc(x, axis=self.axis))
3333

@@ -350,8 +350,9 @@ def size(self):
350350
prod = _groupby_function('prod', 'prod', np.prod)
351351
min = _groupby_function('min', 'min', np.min)
352352
max = _groupby_function('max', 'max', np.max)
353-
first = _groupby_function('first', 'first', _first_compat)
354-
last = _groupby_function('last', 'last', _last_compat)
353+
first = _groupby_function('first', 'first', _first_compat,
354+
numeric_only=False)
355+
last = _groupby_function('last', 'last', _last_compat, numeric_only=False)
355356

356357
def ohlc(self):
357358
"""
@@ -370,10 +371,11 @@ def picker(arr):
370371
return np.nan
371372
return self.agg(picker)
372373

373-
def _cython_agg_general(self, how):
374+
def _cython_agg_general(self, how, numeric_only=True):
374375
output = {}
375376
for name, obj in self._iterate_slices():
376-
if not issubclass(obj.dtype.type, (np.number, np.bool_)):
377+
is_numeric = issubclass(obj.dtype.type, (np.number, np.bool_))
378+
if numeric_only and not is_numeric:
377379
continue
378380

379381
result, names = self.grouper.aggregate(obj.values, how)
@@ -668,6 +670,11 @@ def get_group_levels(self):
668670
'last': lib.group_last
669671
}
670672

673+
_cython_object_functions = {
674+
'first' : lambda a, b, c, d: lib.group_nth_object(a, b, c, d, 1),
675+
'last' : lib.group_last_object
676+
}
677+
671678
_cython_transforms = {
672679
'std' : np.sqrt
673680
}
@@ -681,7 +688,13 @@ def get_group_levels(self):
681688
_filter_empty_groups = True
682689

683690
def aggregate(self, values, how, axis=0):
684-
values = com._ensure_float64(values)
691+
values = com.ensure_float(values)
692+
is_numeric = True
693+
694+
if not issubclass(values.dtype.type, (np.number, np.bool_)):
695+
values = values.astype(object)
696+
is_numeric = False
697+
685698
arity = self._cython_arity.get(how, 1)
686699

687700
vdim = values.ndim
@@ -698,15 +711,19 @@ def aggregate(self, values, how, axis=0):
698711
out_shape = (self.ngroups,) + values.shape[1:]
699712

700713
# will be filled in Cython function
701-
result = np.empty(out_shape, dtype=np.float64)
714+
result = np.empty(out_shape, dtype=values.dtype)
702715
counts = np.zeros(self.ngroups, dtype=np.int64)
703716

704-
result = self._aggregate(result, counts, values, how)
717+
result = self._aggregate(result, counts, values, how, is_numeric)
705718

706719
if self._filter_empty_groups:
707720
if result.ndim == 2:
708-
result = lib.row_bool_subset(result,
709-
(counts > 0).view(np.uint8))
721+
if is_numeric:
722+
result = lib.row_bool_subset(result,
723+
(counts > 0).view(np.uint8))
724+
else:
725+
result = lib.row_bool_subset_object(result,
726+
(counts > 0).view(np.uint8))
710727
else:
711728
result = result[counts > 0]
712729

@@ -724,8 +741,11 @@ def aggregate(self, values, how, axis=0):
724741

725742
return result, names
726743

727-
def _aggregate(self, result, counts, values, how):
728-
agg_func = self._cython_functions[how]
744+
def _aggregate(self, result, counts, values, how, is_numeric):
745+
fdict = self._cython_functions
746+
if not is_numeric:
747+
fdict = self._cython_object_functions
748+
agg_func = fdict[how]
729749
trans_func = self._cython_transforms.get(how, lambda x: x)
730750

731751
comp_ids, _, ngroups = self.group_info
@@ -913,14 +933,22 @@ def names(self):
913933
'last': lib.group_last_bin
914934
}
915935

936+
_cython_object_functions = {
937+
'first' : lambda a, b, c, d: lib.group_nth_bin_object(a, b, c, d, 1),
938+
'last' : lib.group_last_bin_object
939+
}
940+
916941
_name_functions = {
917942
'ohlc' : lambda *args: ['open', 'high', 'low', 'close']
918943
}
919944

920945
_filter_empty_groups = True
921946

922-
def _aggregate(self, result, counts, values, how):
923-
agg_func = self._cython_functions[how]
947+
def _aggregate(self, result, counts, values, how, is_numeric=True):
948+
fdict = self._cython_functions
949+
if not is_numeric:
950+
fdict = self._cython_object_functions
951+
agg_func = fdict[how]
924952
trans_func = self._cython_transforms.get(how, lambda x: x)
925953

926954
if values.ndim > 3:
@@ -1385,8 +1413,8 @@ def _iterate_slices(self):
13851413

13861414
yield val, slicer(val)
13871415

1388-
def _cython_agg_general(self, how):
1389-
new_blocks = self._cython_agg_blocks(how)
1416+
def _cython_agg_general(self, how, numeric_only=True):
1417+
new_blocks = self._cython_agg_blocks(how, numeric_only=numeric_only)
13901418
return self._wrap_agged_blocks(new_blocks)
13911419

13921420
def _wrap_agged_blocks(self, blocks):
@@ -1408,18 +1436,20 @@ def _wrap_agged_blocks(self, blocks):
14081436

14091437
_block_agg_axis = 0
14101438

1411-
def _cython_agg_blocks(self, how):
1439+
def _cython_agg_blocks(self, how, numeric_only=True):
14121440
data, agg_axis = self._get_data_to_aggregate()
14131441

14141442
new_blocks = []
14151443

14161444
for block in data.blocks:
14171445
values = block.values
1418-
if not issubclass(values.dtype.type, (np.number, np.bool_)):
1446+
is_numeric = issubclass(values.dtype.type, (np.number, np.bool_))
1447+
if numeric_only and not is_numeric:
14191448
continue
14201449

1421-
values = com._ensure_float64(values)
1422-
result, names = self.grouper.aggregate(values, how, axis=agg_axis)
1450+
if is_numeric:
1451+
values = com.ensure_float(values)
1452+
result, _ = self.grouper.aggregate(values, how, axis=agg_axis)
14231453
newb = make_block(result, block.items, block.ref_items)
14241454
new_blocks.append(newb)
14251455

@@ -2210,5 +2240,3 @@ def complete_dataframe(obj, prev_completions):
22102240
install_ipython_completers()
22112241
except Exception:
22122242
pass
2213-
2214-

0 commit comments

Comments
 (0)