Skip to content

Commit a57bd48

Browse files
committed
ENH: refactored Cython groupby code to not have to sort, GH #93
1 parent 09553cc commit a57bd48

File tree

5 files changed

+197
-152
lines changed

5 files changed

+197
-152
lines changed

pandas/core/groupby.py

+66-12
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pandas.core.generic import NDFrame, PandasObject
88
from pandas.core.index import Index, MultiIndex
99
from pandas.core.internals import BlockManager
10-
from pandas.core.reshape import get_group_index
1110
from pandas.core.series import Series
1211
from pandas.core.panel import Panel
1312
from pandas.util.decorators import cache_readonly
@@ -316,6 +315,14 @@ def mean(self):
316315
"""
317316
return self._cython_agg_general('mean')
318317

318+
def std(self):
319+
"""
320+
Compute mean of groups, excluding missing values
321+
322+
For multiple groupings, the result index will be a MultiIndex
323+
"""
324+
return self._cython_agg_general('std')
325+
319326
def size(self):
320327
"""
321328
Compute group sizes
@@ -356,8 +363,8 @@ def _cython_agg_general(self, how):
356363
else:
357364
continue
358365

359-
result, counts = lib.group_aggregate(obj, label_list,
360-
shape, how=how)
366+
result, counts = cython_aggregate(obj, label_list,
367+
shape, how=how)
361368
result = result.ravel()
362369
mask = counts.ravel() > 0
363370
output[name] = result[mask]
@@ -1315,15 +1322,7 @@ def generate_groups(data, label_list, shape, axis=0, factory=lambda x: x):
13151322
-------
13161323
generator
13171324
"""
1318-
# indexer = np.lexsort(label_list[::-1])
1319-
group_index = get_group_index(label_list, shape)
1320-
na_mask = np.zeros(len(label_list[0]), dtype=bool)
1321-
for arr in label_list:
1322-
na_mask |= arr == -1
1323-
group_index[na_mask] = -1
1324-
indexer = lib.groupsort_indexer(group_index.astype('i4'),
1325-
np.prod(shape))
1326-
1325+
indexer = _get_group_sorter(label_list, shape)
13271326
sorted_labels = [labels.take(indexer) for labels in label_list]
13281327

13291328
if isinstance(data, BlockManager):
@@ -1342,6 +1341,17 @@ def generate_groups(data, label_list, shape, axis=0, factory=lambda x: x):
13421341
for key, group in gen:
13431342
yield key, group
13441343

1344+
def _get_group_sorter(label_list, shape):
1345+
group_index = get_group_index(label_list, shape)
1346+
na_mask = np.zeros(len(label_list[0]), dtype=bool)
1347+
for arr in label_list:
1348+
na_mask |= arr == -1
1349+
group_index[na_mask] = -1
1350+
indexer = lib.groupsort_indexer(group_index.astype('i4'),
1351+
np.prod(shape))
1352+
1353+
return indexer
1354+
13451355
def _generate_groups(data, labels, shape, start, end, axis=0, which=0,
13461356
factory=lambda x: x):
13471357
axis_labels = labels[which][start:end]
@@ -1385,6 +1395,50 @@ def slicer(data, slob):
13851395

13861396
left = right
13871397

1398+
def get_group_index(label_list, shape):
1399+
n = len(label_list[0])
1400+
group_index = np.zeros(n, dtype=int)
1401+
mask = np.zeros(n, dtype=bool)
1402+
for i in xrange(len(shape)):
1403+
stride = np.prod([x for x in shape[i+1:]], dtype=int)
1404+
group_index += label_list[i] * stride
1405+
mask |= label_list[i] < 0
1406+
1407+
np.putmask(group_index, mask, -1)
1408+
return group_index
1409+
1410+
#----------------------------------------------------------------------
1411+
# Group aggregations in Cython
1412+
1413+
1414+
def cython_aggregate(values, label_list, shape, how='add'):
1415+
agg_func = _cython_functions[how]
1416+
trans_func = _cython_transforms.get(how, lambda x: x)
1417+
1418+
group_index = get_group_index(label_list, shape).astype('i4')
1419+
1420+
result = np.empty(shape, dtype=np.float64)
1421+
result.fill(np.nan)
1422+
1423+
counts = np.zeros(shape, dtype=np.int32)
1424+
agg_func(result.ravel(), counts.ravel(), values,
1425+
group_index)
1426+
1427+
result = trans_func(result)
1428+
1429+
return result, counts
1430+
1431+
_cython_functions = {
1432+
'add' : lib.group_add,
1433+
'mean' : lib.group_mean,
1434+
'var' : lib.group_var,
1435+
'std' : lib.group_var
1436+
}
1437+
1438+
_cython_transforms = {
1439+
'std' : np.sqrt
1440+
}
1441+
13881442
#----------------------------------------------------------------------
13891443
# sorting levels...cleverly?
13901444

pandas/core/reshape.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas.core.series import Series
1010

1111
from pandas.core.common import notnull
12+
from pandas.core.groupby import get_group_index
1213
from pandas.core.index import MultiIndex
1314

1415

@@ -197,13 +198,6 @@ def get_new_index(self):
197198

198199
return new_index
199200

200-
def get_group_index(label_list, shape):
201-
group_index = np.zeros(len(label_list[0]), dtype=int)
202-
for i in xrange(len(shape)):
203-
stride = np.prod([x for x in shape[i+1:]], dtype=int)
204-
group_index += label_list[i] * stride
205-
return group_index
206-
207201
def pivot(self, index=None, columns=None, values=None):
208202
"""
209203
See DataFrame.pivot

0 commit comments

Comments
 (0)