
Commit 095f7c8

ENH: optimized Cython groupby routines for aggregating 2D blocks, added vbenchmark, GH #745
1 parent 922d041 commit 095f7c8

File tree

7 files changed, +259 -75 lines changed

RELEASE.rst

Lines changed: 2 additions & 0 deletions
@@ -175,6 +175,8 @@ pandas 0.7.0
 - Add option to Series.to_csv to omit the index (PR #684)
 - Add ``delimiter`` as an alternative to ``sep`` in ``read_csv`` and other
   parsing functions
+- Substantially improved performance of groupby on DataFrames with many
+  columns by aggregating blocks of columns all at once (GH #745)
 
 **Bug fixes**
 
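As a rough illustration of the workload this note targets, a groupby over a wide, all-float DataFrame now aggregates each homogeneous block in one call rather than one column at a time. The shapes and column name below are made up for the example, not taken from the commit or its benchmark:

    import numpy as np
    import pandas as pd

    # One grouping key plus 100 numeric columns stored in a single float64
    # block, which is the case the block-wise path speeds up.
    df = pd.DataFrame(np.random.randn(10000, 100))
    df['key'] = np.random.randint(0, 100, size=10000)

    result = df.groupby('key').sum()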
pandas/core/common.py

Lines changed: 22 additions & 0 deletions
@@ -531,6 +531,28 @@ def is_float_dtype(arr_or_dtype):
     tipo = arr_or_dtype.dtype.type
     return issubclass(tipo, np.floating)
 
+
+def _ensure_float64(arr):
+    if arr.dtype != np.float64:
+        arr = arr.astype(np.float64)
+    return arr
+
+def _ensure_int64(arr):
+    if arr.dtype != np.int64:
+        arr = arr.astype(np.int64)
+    return arr
+
+def _ensure_int32(arr):
+    if arr.dtype != np.int32:
+        arr = arr.astype(np.int32)
+    return arr
+
+def _ensure_object(arr):
+    if arr.dtype != np.object_:
+        arr = arr.astype('O')
+    return arr
+
+
 def save(obj, path):
     """
     Pickle (serialize) object to input file path
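For orientation, a minimal standalone sketch of how these coercion helpers behave (numpy only; the assertions are illustrative and not part of the commit): they call astype only when the dtype actually differs, so an array already in the target dtype is passed through without a copy.

    import numpy as np

    def _ensure_float64(arr):
        # Same pattern as the helpers added above: convert only when needed.
        if arr.dtype != np.float64:
            arr = arr.astype(np.float64)
        return arr

    ints = np.array([1, 2, 3], dtype=np.int32)
    floats = np.array([1.0, 2.0, 3.0])

    assert _ensure_float64(ints).dtype == np.float64   # converted copy
    assert _ensure_float64(floats) is floats           # already float64, passed through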

pandas/core/groupby.py

Lines changed: 90 additions & 17 deletions
@@ -6,7 +6,7 @@
 from pandas.core.frame import DataFrame
 from pandas.core.generic import NDFrame
 from pandas.core.index import Index, MultiIndex
-from pandas.core.internals import BlockManager
+from pandas.core.internals import BlockManager, make_block
 from pandas.core.series import Series
 from pandas.core.panel import Panel
 from pandas.util.decorators import cache_readonly, Appender
@@ -156,7 +156,7 @@ def _group_shape(self):
         return tuple(ping.ngroups for ping in self.groupings)
 
     def __getattr__(self, attr):
-        if hasattr(self.obj, attr):
+        if hasattr(self.obj, attr) and attr != '_cache':
             return self._make_wrapper(attr)
         raise AttributeError("'%s' object has no attribute '%s'" %
                              (type(self).__name__, attr))
@@ -352,9 +352,7 @@ def _cython_agg_general(self, how):
             if not issubclass(obj.dtype.type, (np.number, np.bool_)):
                 continue
 
-            if obj.dtype != np.float64:
-                obj = obj.astype('f8')
-
+            obj = com._ensure_float64(obj)
             result, counts = cython_aggregate(obj, comp_ids,
                                               max_group, how=how)
             mask = counts > 0
@@ -395,10 +393,7 @@ def _python_agg_general(self, func, *args, **kwargs):
     def _group_index(self):
         result = get_group_index([ping.labels for ping in self.groupings],
                                  self._group_shape)
-
-        if result.dtype != np.int64: # pragma: no cover
-            result = result.astype('i8')
-        return result
+        return com._ensure_int64(result)
 
     def _get_multi_index(self, mask, obs_ids):
         masked = [labels for _, labels in
@@ -642,9 +637,7 @@ def _make_labels(self):
         if self._was_factor: # pragma: no cover
             raise Exception('Should not call this method grouping by level')
         else:
-            values = self.grouper
-            if values.dtype != np.object_:
-                values = values.astype('O')
+            values = com._ensure_object(self.grouper)
 
         # khash
         rizer = lib.Factorizer(len(values))
@@ -955,6 +948,73 @@ def _iterate_slices(self):
 
             yield val, slicer(val)
 
+
+    def _cython_agg_general(self, how):
+
+        group_index = self._group_index
+        comp_ids, obs_group_ids = _compress_group_index(group_index)
+        max_group = len(obs_group_ids)
+
+        obj = self._obj_with_exclusions
+        if self.axis == 1:
+            obj = obj.T
+
+        new_blocks = []
+
+        for block in obj._data.blocks:
+            values = block.values.T
+            if not issubclass(values.dtype.type, (np.number, np.bool_)):
+                continue
+
+            values = com._ensure_float64(values)
+            result, counts = cython_aggregate(values, comp_ids,
+                                              max_group, how=how)
+
+            mask = counts > 0
+            if len(mask) > 0:
+                result = result[mask]
+            newb = make_block(result.T, block.items, block.ref_items)
+            new_blocks.append(newb)
+
+        if len(new_blocks) == 0:
+            raise GroupByError('No numeric types to aggregate')
+
+        agg_axis = 0 if self.axis == 1 else 1
+        agg_labels = self._obj_with_exclusions._get_axis(agg_axis)
+
+        if sum(len(x.items) for x in new_blocks) == len(agg_labels):
+            output_keys = agg_labels
+        else:
+            output_keys = []
+            for b in new_blocks:
+                output_keys.extend(b.items)
+            try:
+                output_keys.sort()
+            except TypeError: # pragma
+                pass
+
+            if isinstance(agg_labels, MultiIndex):
+                output_keys = MultiIndex.from_tuples(output_keys,
+                                                     names=agg_labels.names)
+
+        if not self.as_index:
+            index = np.arange(new_blocks[0].values.shape[1])
+            mgr = BlockManager(new_blocks, [output_keys, index])
+            result = DataFrame(mgr)
+            group_levels = self._get_group_levels(mask, obs_group_ids)
+            for i, (name, labels) in enumerate(group_levels):
+                result.insert(i, name, labels)
+            result = result.consolidate()
+        else:
+            index = self._get_multi_index(mask, obs_group_ids)
+            mgr = BlockManager(new_blocks, [output_keys, index])
+            result = DataFrame(mgr)
+
+        if self.axis == 1:
+            result = result.T
+
+        return result
+
     @cache_readonly
     def _obj_with_exclusions(self):
         if self._column is not None:
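The core idea of the new DataFrame path can be sketched without pandas internals: instead of aggregating one 1D column at a time, the whole 2D block is reduced per group in a single pass. A hedged numpy-only analogue follows; block_group_sum, the shapes, and the use of np.add.at are illustrative stand-ins, while the commit's real kernels are Cython routines that also handle NaNs and other aggregations.

    import numpy as np

    def block_group_sum(values, comp_ids, ngroups):
        # values: float64 block of shape (nrows, ncols)
        # comp_ids: compressed group id per row, in the range [0, ngroups)
        out = np.zeros((ngroups, values.shape[1]), dtype=np.float64)
        counts = np.zeros(ngroups, dtype=np.int64)
        # One unbuffered scatter-add handles every column of the block at once,
        # the 2D analogue of running a 1D group-sum column by column.
        np.add.at(out, comp_ids, values)
        np.add.at(counts, comp_ids, 1)
        return out, counts

    values = np.random.randn(6, 3)
    comp_ids = np.array([0, 1, 0, 2, 1, 0])
    sums, counts = block_group_sum(values, comp_ids, ngroups=3)  # sums.shape == (3, 3)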
@@ -1282,8 +1342,9 @@ def generate_groups(data, group_index, ngroups, axis=0, factory=lambda x: x):
     -------
     generator
     """
-    indexer = lib.groupsort_indexer(group_index.astype('i4'),
-                                    ngroups)[0]
+    group_index = com._ensure_int32(group_index)
+
+    indexer = lib.groupsort_indexer(group_index, ngroups)[0]
     group_index = group_index.take(indexer)
 
     if isinstance(data, BlockManager):
@@ -1312,8 +1373,7 @@ def _get_slice(slob):
         def _get_slice(slob):
             return sorted_data[slob]
 
-    starts, ends = lib.generate_slices(group_index.astype('i4'),
-                                       ngroups)
+    starts, ends = lib.generate_slices(group_index, ngroups)
 
     for i, (start, end) in enumerate(zip(starts, ends)):
         # Since I'm now compressing the group ids, it's now not "possible" to
@@ -1385,14 +1445,27 @@ def get_key(self, comp_id):
 
 def cython_aggregate(values, group_index, ngroups, how='add'):
     agg_func = _cython_functions[how]
+    if values.ndim == 1:
+        squeeze = True
+        values = values[:, None]
+        out_shape = (ngroups, 1)
+    else:
+        squeeze = False
+        out_shape = (ngroups, values.shape[1])
+
     trans_func = _cython_transforms.get(how, lambda x: x)
 
-    result = np.empty(ngroups, dtype=np.float64)
+    result = np.empty(out_shape, dtype=np.float64)
     result.fill(np.nan)
 
     counts = np.zeros(ngroups, dtype=np.int32)
+
     agg_func(result, counts, values, group_index)
     result = trans_func(result)
+
+    if squeeze:
+        result = result.squeeze()
+
    return result, counts
 
 _cython_functions = {
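The 1D branch added to cython_aggregate is pure shape bookkeeping: promote a single column to a one-column 2D block, aggregate, then squeeze back. A standalone sketch of just that bookkeeping, assuming a sum reduction via np.add.at as a stand-in for the Cython agg_func (the real routine starts from an NaN-filled result and masks empty groups through counts, which this sketch omits):

    import numpy as np

    def aggregate_maybe_1d(values, group_index, ngroups):
        # Mirror of the shape handling added above: promote 1D input to a
        # one-column 2D block, aggregate, then squeeze back to 1D.
        squeeze = values.ndim == 1
        if squeeze:
            values = values[:, None]

        result = np.zeros((ngroups, values.shape[1]), dtype=np.float64)
        counts = np.zeros(ngroups, dtype=np.int32)
        np.add.at(result, group_index, values)   # stand-in for the Cython agg_func
        np.add.at(counts, group_index, 1)

        if squeeze:
            result = result.squeeze(axis=1)
        return result, counts

    ids = np.array([0, 1, 0, 1])
    vals_1d = np.array([1.0, 2.0, 3.0, 4.0])
    vals_2d = np.arange(8, dtype=np.float64).reshape(4, 2)
    print(aggregate_maybe_1d(vals_1d, ids, ngroups=2)[0].shape)  # (2,)
    print(aggregate_maybe_1d(vals_2d, ids, ngroups=2)[0].shape)  # (2, 2)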
