Commit 9bb210c

ENH: multi-GroupBy refactoring to be less nested, reuse more code, getting toward addressing #496
1 parent b16517e commit 9bb210c

File tree: 3 files changed, +74 -178 lines changed

pandas/core/groupby.py
pandas/src/groupby.pyx
pandas/src/reduce.pyx

pandas/core/groupby.py

Lines changed: 44 additions & 102 deletions
@@ -232,20 +232,15 @@ def _multi_iter(self):
         elif isinstance(self.obj, Series):
             tipo = Series
 
-        def flatten(gen, level=0, shape_axis=0):
-            ids = self.groupings[level].ids
-            for cat, subgen in gen:
-                if subgen is None:
-                    continue
-
-                if isinstance(subgen, tipo):
-                    yield (ids[cat],), subgen
-                else:
-                    for subcat, data in flatten(subgen, level=level+1,
-                                                shape_axis=shape_axis):
-                        yield (ids[cat],) + subcat, data
+        id_list = [ping.ids for ping in self.groupings]
+        shape = tuple(len(ids) for ids in id_list)
 
-        return flatten(self._generator_factory(data), shape_axis=self.axis)
+        for label, group in self._generator_factory(data):
+            if group is None:
+                continue
+            unraveled = np.unravel_index(label, shape)
+            key = tuple(id_list[i][j] for i, j in enumerate(unraveled))
+            yield key, group
 
     def apply(self, func, *args, **kwargs):
         """
@@ -387,51 +382,31 @@ def _python_agg_general(self, func, *args, **kwargs):
         group_shape = self._group_shape
         counts = np.zeros(group_shape, dtype=int)
 
-        # want to cythonize?
-        def _doit(reschunk, ctchunk, gen, shape_axis=0):
-            for i, (_, subgen) in enumerate(gen):
-                # TODO: fixme
-                if subgen is None:
+        # todo: cythonize?
+        def _aggregate(output, counts, generator, shape_axis=0):
+            for label, group in generator:
+                if group is None:
                     continue
+                counts[label] = group.shape[shape_axis]
+                output[label] = func(group, *args, **kwargs)
 
-                if isinstance(subgen, PandasObject):
-                    size = subgen.shape[shape_axis]
-                    ctchunk[i] = size
-                    reschunk[i] = func(subgen, *args, **kwargs)
-                else:
-                    _doit(reschunk[i], ctchunk[i], subgen,
-                          shape_axis=shape_axis)
-
-        gen_factory = self._generator_factory
-
-        try:
-            stride_shape = self._agg_stride_shape
-            output = np.empty(group_shape + stride_shape, dtype=float)
-            output.fill(np.nan)
-            obj = self._obj_with_exclusions
-            _doit(output, counts, gen_factory(obj), shape_axis=self.axis)
-            mask = counts.ravel() > 0
-            output = output.reshape((np.prod(group_shape),) + stride_shape)
-            output = output[mask]
-        except Exception:
-            # we failed, try to go slice-by-slice / column-by-column
-
-            result = np.empty(group_shape, dtype=float)
-            result.fill(np.nan)
-            # iterate through "columns" ex exclusions to populate output dict
-            output = {}
-            for name, obj in self._iterate_slices():
-                try:
-                    _doit(result, counts, gen_factory(obj))
-                    # TODO: same mask for every column...
-                    output[name] = result.ravel().copy()
-                    result.fill(np.nan)
-                except TypeError:
-                    continue
+        result = np.empty(group_shape, dtype=float)
+        result.fill(np.nan)
+        # iterate through "columns" ex exclusions to populate output dict
+        output = {}
+        for name, obj in self._iterate_slices():
+            try:
+                _aggregate(result.ravel(), counts.ravel(),
+                           self._generator_factory(obj))
+                # TODO: same mask for every column...
+                output[name] = result.ravel().copy()
+                result.fill(np.nan)
+            except TypeError:
+                continue
 
-            mask = counts.ravel() > 0
-            for name, result in output.iteritems():
-                output[name] = result[mask]
+        mask = counts.ravel() > 0
+        for name, result in output.iteritems():
+            output[name] = result[mask]
 
         return self._wrap_aggregated_output(output, mask)
 
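Note on the hunk above: with groups delivered flat, _aggregate indexes raveled views of the result and counts arrays directly instead of recursing through nested chunks. A self-contained sketch of that pattern (labels and group arrays invented for illustration):

    import numpy as np

    group_shape = (2, 3)
    result = np.empty(group_shape, dtype=float)
    result.fill(np.nan)
    counts = np.zeros(group_shape, dtype=int)

    out, cts = result.ravel(), counts.ravel()   # views into result/counts
    for label, group in [(0, np.array([1., 2.])), (4, np.array([5.]))]:
        cts[label] = group.shape[0]
        out[label] = group.sum()                # stand-in for func(group)

    mask = counts.ravel() > 0
    print(result.ravel()[mask])                 # [3. 5.] -- empty groups dropped
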
@@ -869,7 +844,7 @@ class DataFrameGroupBy(GroupBy):
     def _agg_stride_shape(self):
         if self._column is not None:
             # ffffff
-            return 1
+            return 1,
 
         if self.axis == 0:
             n = len(self.obj.columns)
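
Note on the hunk above: _agg_stride_shape is consumed as a tuple (e.g. group_shape + stride_shape in the old _python_agg_general), so the trailing comma is the fix: 1, is the one-tuple (1,), while the bare integer 1 cannot be concatenated:

    group_shape = (2, 3)
    print(group_shape + (1,))   # (2, 3, 1)
    # group_shape + 1           # TypeError: can only concatenate tuple (not "int")
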
@@ -1322,8 +1297,14 @@ def generate_groups(data, label_list, shape, axis=0, factory=lambda x: x):
     -------
     generator
     """
-    indexer = _get_group_sorter(label_list, shape)
-    sorted_labels = [labels.take(indexer) for labels in label_list]
+    group_index = get_group_index(label_list, shape)
+    na_mask = np.zeros(len(label_list[0]), dtype=bool)
+    for arr in label_list:
+        na_mask |= arr == -1
+    group_index[na_mask] = -1
+    indexer = lib.groupsort_indexer(group_index.astype('i4'),
+                                    np.prod(shape))
+    group_index = group_index.take(indexer)
 
     if isinstance(data, BlockManager):
         # this is sort of wasteful but...
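
Note on the hunk above: get_group_index flattens the per-key label arrays into one row-major group index, and rows with a missing value (-1) in any key are forced to -1 before sorting. A small worked example of the same computation (data invented for illustration):

    import numpy as np

    labels1 = np.array([0, 1, 1, 0, -1])    # -1 marks a missing key value
    labels2 = np.array([2, 0, 2, 1, 1])
    shape = (2, 3)

    group_index = labels1 * shape[1] + labels2   # row-major flat index
    na_mask = (labels1 == -1) | (labels2 == -1)
    group_index[na_mask] = -1
    print(group_index)                           # [ 2  3  5  1 -1]
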
@@ -1335,29 +1316,6 @@ def generate_groups(data, label_list, shape, axis=0, factory=lambda x: x):
     elif isinstance(data, DataFrame):
         sorted_data = data.take(indexer, axis=axis)
 
-    gen = _generate_groups(sorted_data, sorted_labels, shape,
-                           0, len(label_list[0]), axis=axis, which=0,
-                           factory=factory)
-    for key, group in gen:
-        yield key, group
-
-def _get_group_sorter(label_list, shape):
-    group_index = get_group_index(label_list, shape)
-    na_mask = np.zeros(len(label_list[0]), dtype=bool)
-    for arr in label_list:
-        na_mask |= arr == -1
-    group_index[na_mask] = -1
-    indexer = lib.groupsort_indexer(group_index.astype('i4'),
-                                    np.prod(shape))
-
-    return indexer
-
-def _generate_groups(data, labels, shape, start, end, axis=0, which=0,
-                     factory=lambda x: x):
-    axis_labels = labels[which][start:end]
-    edges = axis_labels.searchsorted(np.arange(1, shape[which] + 1),
-                                     side='left')
-
     if isinstance(data, DataFrame):
         def slicer(data, slob):
             if axis == 0:
@@ -1371,29 +1329,13 @@ def slicer(data, slob):
         def slicer(data, slob):
             return data[slob]
 
-    do_slice = which == len(labels) - 1
+    starts, ends = lib.generate_slices(group_index, np.prod(shape))
 
-    # omit -1 values at beginning-- NA values
-    left = axis_labels.searchsorted(0)
-
-    # time to actually aggregate
-    for i, right in enumerate(edges):
-        if do_slice:
-            slob = slice(start + left, start + right)
-
-            # skip empty groups in the cartesian product
-            if left == right:
-                yield i, None
-                continue
-
-            yield i, slicer(data, slob)
+    for i, (start, end) in enumerate(zip(starts, ends)):
+        if start == end:
+            yield i, None
         else:
-            # yield subgenerators, yikes
-            yield i, _generate_groups(data, labels, shape, start + left,
-                                      start + right, axis=axis,
-                                      which=which + 1, factory=factory)
-
-            left = right
+            yield i, slicer(sorted_data, slice(start, end))
 
 def get_group_index(label_list, shape):
     n = len(label_list[0])
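
Note on the two hunks above: once the data is sorted by the flat group index, every group occupies one contiguous run of rows, so generate_groups can yield plain slices instead of nested subgenerators. A rough end-to-end sketch, with np.argsort standing in for lib.groupsort_indexer and searchsorted standing in for lib.generate_slices (both substitutions are assumptions; the real helpers also handle -1/NA labels):

    import numpy as np

    labels1 = np.array([1, 0, 1, 0])
    labels2 = np.array([0, 1, 1, 1])
    shape = (2, 2)

    group_index = labels1 * shape[1] + labels2        # get_group_index, row-major
    indexer = np.argsort(group_index, kind='stable')  # ~ lib.groupsort_indexer
    group_index = group_index.take(indexer)           # [1 1 2 3]
    data = np.array([10., 20., 30., 40.]).take(indexer)

    ngroups = np.prod(shape)
    starts = group_index.searchsorted(np.arange(ngroups), side='left')
    ends = group_index.searchsorted(np.arange(ngroups), side='right')

    for i, (start, end) in enumerate(zip(starts, ends)):
        print(i, None if start == end else data[start:end])
    # 0 None
    # 1 [20. 40.]
    # 2 [10.]
    # 3 [30.]
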

pandas/src/groupby.pyx

Lines changed: 28 additions & 75 deletions
@@ -261,44 +261,6 @@ def get_unique_labels(ndarray[object] values, dict idMap):
 
     return fillVec
 
-# from libcpp.set cimport set as stlset
-
-# cdef fast_unique_int32(ndarray arr):
-#     cdef:
-#         cdef stlset[int] table
-
-#         Py_ssize_t i, n = len(arr)
-#         int32_t* values
-#         list uniques = []
-#         int32_t val
-
-#     values = <int32_t*> arr.data
-
-#     for i from 0 <= i < n:
-#         val = values[i]
-#         if table.count(val) == 0:
-#             table.insert(val)
-#             uniques.append(val)
-#     return np.asarray(sorted(uniques), dtype=object)
-
-
-def _group_reorder(values, label_list, shape):
-    # group_index = np.zeros(len(label_list[0]), dtype='i4')
-    # for i in xrange(len(shape)):
-    #     stride = np.prod([x for x in shape[i+1:]], dtype='i4')
-    #     group_index += label_list[i] * stride
-    # na_mask = np.zeros(len(label_list[0]), dtype=bool)
-    # for arr in label_list:
-    #     na_mask |= arr == -1
-    # group_index[na_mask] = -1
-
-    # indexer = groupsort_indexer(group_index, np.prod(shape))
-
-    indexer = np.lexsort(label_list[::-1])
-    sorted_labels = [labels.take(indexer) for labels in label_list]
-    sorted_values = values.take(indexer)
-    return sorted_values, sorted_labels
-
 @cython.wraparound(False)
 def groupsort_indexer(ndarray[int32_t] index, Py_ssize_t ngroups):
     cdef:
@@ -326,39 +288,10 @@ def groupsort_indexer(ndarray[int32_t] index, Py_ssize_t ngroups):
     return result
 
 
-# cdef int _aggregate_group(float64_t *out, int32_t *counts, float64_t *values,
-#                           list labels, int start, int end, tuple shape,
-#                           Py_ssize_t which, Py_ssize_t offset,
-#                           agg_func func) except -1:
-#     cdef:
-#         ndarray[int32_t] axis
-#         cdef Py_ssize_t stride
-
-#     # time to actually aggregate
-#     if which == len(labels) - 1:
-#         axis = labels[which]
-
-#         while start < end and axis[start] == -1:
-#             start += 1
-#         func(out, counts, values, <int32_t*> axis.data, start, end, offset)
-#     else:
-#         axis = labels[which][start:end]
-#         stride = np.prod(shape[which+1:])
-#         # get group counts on axisp
-#         edges = axis.searchsorted(np.arange(1, shape[which] + 1), side='left')
-#         # print edges, axis
-
-#         left = axis.searchsorted(0) # ignore NA values coded as -1
-
-#         # aggregate each subgroup
-#         for right in edges:
-#             _aggregate_group(out, counts, values, labels, start + left,
-#                              start + right, shape, which + 1, offset, func)
-#             offset += stride
-#             left = right
 
 # TODO: aggregate multiple columns in single pass
 
+@cython.boundscheck(False)
 @cython.wraparound(False)
 def group_add(ndarray[float64_t] out,
               ndarray[int32_t] counts,
@@ -391,6 +324,7 @@ def group_add(ndarray[float64_t] out,
         else:
             out[i] = sumx[i]
 
+@cython.boundscheck(False)
 @cython.wraparound(False)
 def group_mean(ndarray[float64_t] out,
                ndarray[int32_t] counts,
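
Note on the hunks above and below: the kernels themselves are unchanged here apart from the added @cython.boundscheck(False); for orientation, a rough pure-Python sketch of what a kernel like group_add computes (the exact NA handling is an assumption, not shown in this diff):

    import numpy as np

    def group_add_py(out, counts, values, labels):
        # accumulate per-label sums and observation counts
        sumx = np.zeros(len(out))
        nobs = np.zeros(len(out), dtype=int)
        for val, lab in zip(values, labels):
            if lab == -1 or np.isnan(val):   # assumed: skip NA labels/values
                continue
            nobs[lab] += 1
            sumx[lab] += val
        for i in range(len(out)):
            counts[i] = nobs[i]
            out[i] = np.nan if nobs[i] == 0 else sumx[i]  # cf. the else: above

    out = np.empty(3)
    counts = np.zeros(3, dtype='i4')
    group_add_py(out, counts, np.array([1., 2., 4.]), np.array([0, 0, 2]))
    print(out, counts)   # [ 3. nan  4.] [2 0 1]
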
@@ -424,6 +358,7 @@ def group_mean(ndarray[float64_t] out,
         else:
             out[i] = sumx[i] / count
 
+@cython.boundscheck(False)
 @cython.wraparound(False)
 def group_var(ndarray[float64_t] out,
               ndarray[int32_t] counts,
@@ -460,13 +395,6 @@ def group_var(ndarray[float64_t] out,
             out[i] = ((ct * sumxx[i] - sumx[i] * sumx[i]) /
                       (ct * ct - ct))
 
-def _result_shape(label_list):
-    # assumed sorted
-    shape = []
-    for labels in label_list:
-        shape.append(1 + labels[-1])
-    return tuple(shape)
-
 def reduce_mean(ndarray[object] indices,
                 ndarray[object] buckets,
                 ndarray[float64_t] values,
@@ -584,6 +512,31 @@ def duplicated(list values, take_last=False):
 
     return result.view(np.bool_)
 
+
+def generate_slices(ndarray[Py_ssize_t] labels, Py_ssize_t ngroups):
+    cdef:
+        Py_ssize_t i, group_size, n, lab, start
+        object slobj
+        ndarray[int32_t] starts
+
+    n = len(labels)
+
+    starts = np.zeros(ngroups, dtype='i4')
+    ends = np.zeros(ngroups, dtype='i4')
+
+    start = 0
+    group_size = 0
+    for i in range(n):
+        group_size += 1
+        lab = labels[i]
+        if i == n - 1 or lab != labels[i + 1]:
+            starts[lab] = start
+            ends[lab] = start + group_size
+            start += group_size
+            group_size = 0
+
+    return starts, ends
+
 '''
 
 def ts_upsample_mean(ndarray[object] indices,
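
Note on the hunk above: generate_slices scans the group-sorted labels once and records each group's contiguous [start, end) run; groups that never appear keep start == end == 0, which generate_groups turns into a None yield. A pure-Python rendering of the same loop, handy for checking the logic:

    import numpy as np

    def generate_slices_py(labels, ngroups):
        n = len(labels)
        starts = np.zeros(ngroups, dtype='i4')
        ends = np.zeros(ngroups, dtype='i4')
        start = 0
        group_size = 0
        for i in range(n):
            group_size += 1
            lab = labels[i]
            if i == n - 1 or lab != labels[i + 1]:   # end of a run
                starts[lab] = start
                ends[lab] = start + group_size
                start += group_size
                group_size = 0
        return starts, ends

    starts, ends = generate_slices_py(np.array([0, 0, 1, 3, 3]), 4)
    print(starts, ends)   # [0 2 0 3] [2 3 0 5] -- group 2 is empty
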

pandas/src/reduce.pyx

Lines changed: 2 additions & 1 deletion
@@ -99,7 +99,8 @@ cdef class Grouper:
         object arr, dummy, f, labels, counts
         bint passed_dummy
 
-    def __init__(self, object arr, object f, object labels, ngroups, dummy=None):
+    def __init__(self, object arr, object index, object f,
+                 object labels, ngroups, dummy=None):
         n = len(arr)
 
         assert(arr.ndim == 1)
