@@ -232,20 +232,15 @@ def _multi_iter(self):
         elif isinstance(self.obj, Series):
             tipo = Series

-        def flatten(gen, level=0, shape_axis=0):
-            ids = self.groupings[level].ids
-            for cat, subgen in gen:
-                if subgen is None:
-                    continue
-
-                if isinstance(subgen, tipo):
-                    yield (ids[cat],), subgen
-                else:
-                    for subcat, data in flatten(subgen, level=level + 1,
-                                                shape_axis=shape_axis):
-                        yield (ids[cat],) + subcat, data
+        id_list = [ping.ids for ping in self.groupings]
+        shape = tuple(len(ids) for ids in id_list)

-        return flatten(self._generator_factory(data), shape_axis=self.axis)
+        for label, group in self._generator_factory(data):
+            if group is None:
+                continue
+            unraveled = np.unravel_index(label, shape)
+            key = tuple(id_list[i][j] for i, j in enumerate(unraveled))
+            yield key, group

     def apply(self, func, *args, **kwargs):
         """
@@ -387,51 +382,31 @@ def _python_agg_general(self, func, *args, **kwargs):
         group_shape = self._group_shape
         counts = np.zeros(group_shape, dtype=int)

-        # want to cythonize?
-        def _doit(reschunk, ctchunk, gen, shape_axis=0):
-            for i, (_, subgen) in enumerate(gen):
-                # TODO: fixme
-                if subgen is None:
+        # todo: cythonize?
+        def _aggregate(output, counts, generator, shape_axis=0):
+            for label, group in generator:
+                if group is None:
                     continue
+                counts[label] = group.shape[shape_axis]
+                output[label] = func(group, *args, **kwargs)

-                if isinstance(subgen, PandasObject):
-                    size = subgen.shape[shape_axis]
-                    ctchunk[i] = size
-                    reschunk[i] = func(subgen, *args, **kwargs)
-                else:
-                    _doit(reschunk[i], ctchunk[i], subgen,
-                          shape_axis=shape_axis)
-
-        gen_factory = self._generator_factory
-
-        try:
-            stride_shape = self._agg_stride_shape
-            output = np.empty(group_shape + stride_shape, dtype=float)
-            output.fill(np.nan)
-            obj = self._obj_with_exclusions
-            _doit(output, counts, gen_factory(obj), shape_axis=self.axis)
-            mask = counts.ravel() > 0
-            output = output.reshape((np.prod(group_shape),) + stride_shape)
-            output = output[mask]
-        except Exception:
-            # we failed, try to go slice-by-slice / column-by-column
-
-            result = np.empty(group_shape, dtype=float)
-            result.fill(np.nan)
-            # iterate through "columns" ex exclusions to populate output dict
-            output = {}
-            for name, obj in self._iterate_slices():
-                try:
-                    _doit(result, counts, gen_factory(obj))
-                    # TODO: same mask for every column...
-                    output[name] = result.ravel().copy()
-                    result.fill(np.nan)
-                except TypeError:
-                    continue
+        result = np.empty(group_shape, dtype=float)
+        result.fill(np.nan)
+        # iterate through "columns" ex exclusions to populate output dict
+        output = {}
+        for name, obj in self._iterate_slices():
+            try:
+                _aggregate(result.ravel(), counts.ravel(),
+                           self._generator_factory(obj))
+                # TODO: same mask for every column...
+                output[name] = result.ravel().copy()
+                result.fill(np.nan)
+            except TypeError:
+                continue

-            mask = counts.ravel() > 0
-            for name, result in output.iteritems():
-                output[name] = result[mask]
+        mask = counts.ravel() > 0
+        for name, result in output.iteritems():
+            output[name] = result[mask]

         return self._wrap_aggregated_output(output, mask)

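
Reviewer note (not part of the patch): with the outer try/except removed, _python_agg_general now always aggregates slice-by-slice into a flat NaN-filled array keyed by flat group label, then drops empty cells of the cartesian product via the counts mask. A self-contained sketch of that pattern with a hypothetical aggregate_column helper:

    import numpy as np

    def aggregate_column(labeled_groups, n_groups, func=np.mean):
        # Flat-label aggregation: NaN-fill, record group sizes, mask empty cells.
        result = np.empty(n_groups, dtype=float)
        result.fill(np.nan)
        counts = np.zeros(n_groups, dtype=int)
        for label, group in labeled_groups:   # (flat label, values) pairs; None marks an empty cell
            if group is None:
                continue
            counts[label] = len(group)
            result[label] = func(group)
        mask = counts > 0
        return result[mask], mask

    agg, mask = aggregate_column([(0, np.array([1.0, 2.0])), (1, None),
                                  (2, np.array([5.0]))], n_groups=3)
    print(agg)   # [1.5 5. ]
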
@@ -869,7 +844,7 @@ class DataFrameGroupBy(GroupBy):
     def _agg_stride_shape(self):
         if self._column is not None:
             # ffffff
-            return 1
+            return 1,

         if self.axis == 0:
             n = len(self.obj.columns)
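
Reviewer note: the trailing comma matters because the stride shape gets concatenated onto a group shape when allocating output arrays (as in the removed np.empty(group_shape + stride_shape, ...) call above); 1, is a one-element tuple, while the bare int 1 would raise a TypeError:

    group_shape = (5, 3)
    stride_shape = 1,                    # the tuple (1,), not the int 1
    print(group_shape + stride_shape)    # (5, 3, 1); adding a bare int raises TypeError
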
@@ -1322,8 +1297,14 @@ def generate_groups(data, label_list, shape, axis=0, factory=lambda x: x):
     -------
     generator
     """
-    indexer = _get_group_sorter(label_list, shape)
-    sorted_labels = [labels.take(indexer) for labels in label_list]
+    group_index = get_group_index(label_list, shape)
+    na_mask = np.zeros(len(label_list[0]), dtype=bool)
+    for arr in label_list:
+        na_mask |= arr == -1
+    group_index[na_mask] = -1
+    indexer = lib.groupsort_indexer(group_index.astype('i4'),
+                                    np.prod(shape))
+    group_index = group_index.take(indexer)

     if isinstance(data, BlockManager):
         # this is sort of wasteful but...
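
Reviewer note (not part of the patch): get_group_index (shown further down) flattens the per-level label codes into a single row-major group id, and any row carrying a -1 (NA) code at any level is forced to -1 before sorting. lib.groupsort_indexer itself is a Cython counting sort; a stable np.argsort is used below as a rough pure-numpy stand-in:

    import numpy as np

    label_list = [np.array([0, 1, 1, -1]), np.array([2, 0, 2, 1])]   # -1 marks NA
    shape = (2, 3)

    # Row-major flattening of the two code arrays into one group id per row.
    group_index = label_list[0] * shape[1] + label_list[1]           # [ 2  3  5 -2]

    na_mask = np.zeros(len(label_list[0]), dtype=bool)
    for arr in label_list:
        na_mask |= arr == -1
    group_index[na_mask] = -1                                        # [ 2  3  5 -1]

    # Stand-in for lib.groupsort_indexer: stable sort by group id.
    indexer = np.argsort(group_index, kind='stable')
    print(group_index.take(indexer))                                 # [-1  2  3  5]
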
@@ -1335,29 +1316,6 @@ def generate_groups(data, label_list, shape, axis=0, factory=lambda x: x):
     elif isinstance(data, DataFrame):
         sorted_data = data.take(indexer, axis=axis)

-    gen = _generate_groups(sorted_data, sorted_labels, shape,
-                           0, len(label_list[0]), axis=axis, which=0,
-                           factory=factory)
-    for key, group in gen:
-        yield key, group
-
-def _get_group_sorter(label_list, shape):
-    group_index = get_group_index(label_list, shape)
-    na_mask = np.zeros(len(label_list[0]), dtype=bool)
-    for arr in label_list:
-        na_mask |= arr == -1
-    group_index[na_mask] = -1
-    indexer = lib.groupsort_indexer(group_index.astype('i4'),
-                                    np.prod(shape))
-
-    return indexer
-
-def _generate_groups(data, labels, shape, start, end, axis=0, which=0,
-                     factory=lambda x: x):
-    axis_labels = labels[which][start:end]
-    edges = axis_labels.searchsorted(np.arange(1, shape[which] + 1),
-                                     side='left')
-
     if isinstance(data, DataFrame):
         def slicer(data, slob):
             if axis == 0:
@@ -1371,29 +1329,13 @@ def slicer(data, slob):
         def slicer(data, slob):
             return data[slob]

-    do_slice = which == len(labels) - 1
+    starts, ends = lib.generate_slices(group_index, np.prod(shape))

-    # omit -1 values at beginning-- NA values
-    left = axis_labels.searchsorted(0)
-
-    # time to actually aggregate
-    for i, right in enumerate(edges):
-        if do_slice:
-            slob = slice(start + left, start + right)
-
-            # skip empty groups in the cartesian product
-            if left == right:
-                yield i, None
-                continue
-
-            yield i, slicer(data, slob)
+    for i, (start, end) in enumerate(zip(starts, ends)):
+        if start == end:
+            yield i, None
         else:
-            # yield subgenerators, yikes
-            yield i, _generate_groups(data, labels, shape, start + left,
-                                      start + right, axis=axis,
-                                      which=which + 1, factory=factory)
-
-        left = right
+            yield i, slicer(sorted_data, slice(start, end))

 def get_group_index(label_list, shape):
     n = len(label_list[0])
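
Reviewer note (not part of the patch): lib.generate_slices returns, for each of the np.prod(shape) possible group ids, the [start, end) positions of that id within the sorted group_index; start == end marks an empty cell of the cartesian product, which the patched loop turns into "yield i, None". A pure-numpy stand-in using searchsorted, assuming the NA rows (id -1) sort to the front:

    import numpy as np

    def generate_slices(sorted_group_index, n_groups):
        # [start, end) of each group id within an already-sorted group_index.
        ids = np.arange(n_groups)
        starts = sorted_group_index.searchsorted(ids, side='left')
        ends = sorted_group_index.searchsorted(ids, side='right')
        return starts, ends

    sorted_gi = np.array([-1, 0, 0, 2, 2, 2])
    starts, ends = generate_slices(sorted_gi, 3)
    for i, (start, end) in enumerate(zip(starts, ends)):
        print(i, start, end)   # group 1 has start == end, i.e. an empty slice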