@@ -413,42 +413,54 @@ def indices(self):
413
413
""" dict {group name -> group indices} """
414
414
return self .grouper .indices
415
415
416
- def _get_index (self , name ):
417
- """ safe get index , translate keys for datelike to underlying repr """
416
+ def _get_indices (self , names , raise_on_missing = False ):
417
+ """ safe get multiple indices , translate keys for datelike to underlying repr """
418
418
419
- def convert ( key , s ):
419
+ def get_converter ( s ):
420
420
# possibly convert to they actual key types
421
421
# in the indices, could be a Timestamp or a np.datetime64
422
-
423
422
if isinstance (s , (Timestamp ,datetime .datetime )):
424
- return Timestamp (key )
423
+ return lambda key : Timestamp (key )
425
424
elif isinstance (s , np .datetime64 ):
426
- return Timestamp (key ).asm8
427
- return key
425
+ return lambda key : Timestamp (key ).asm8
426
+ else :
427
+ return lambda key : key
428
428
429
- sample = next (iter (self .indices ))
430
- if isinstance (sample , tuple ):
431
- if not isinstance (name , tuple ):
429
+ if len (names ) == 0 :
430
+ return []
431
+
432
+ index_sample = next (iter (self .indices ))
433
+ name_sample = names [0 ]
434
+ if isinstance (index_sample , tuple ):
435
+ if not isinstance (name_sample , tuple ):
432
436
msg = ("must supply a tuple to get_group with multiple"
433
437
" grouping keys" )
434
438
raise ValueError (msg )
435
- if not len (name ) == len (sample ):
439
+ if not len (name_sample ) == len (index_sample ):
436
440
try :
437
441
# If the original grouper was a tuple
438
- return self .indices [name ]
442
+ return [ self .indices [name ] for name in names ]
439
443
except KeyError :
440
444
# turns out it wasn't a tuple
441
445
msg = ("must supply a a same-length tuple to get_group"
442
446
" with multiple grouping keys" )
443
447
raise ValueError (msg )
444
448
445
- name = tuple ([ convert (n , k ) for n , k in zip (name ,sample ) ])
449
+ converters = [get_converter (s ) for s in index_sample ]
450
+ names = [tuple ([f (n ) for f , n in zip (converters , name )]) for name in names ]
446
451
447
452
else :
453
+ converter = get_converter (index_sample )
454
+ names = [converter (name ) for name in names ]
448
455
449
- name = convert (name , sample )
456
+ if raise_on_missing :
457
+ return [self .indices [name ] for name in names ]
458
+ else :
459
+ return [self .indices .get (name , []) for name in names ]
450
460
451
- return self .indices [name ]
461
+ def _get_index (self , name , raise_on_missing = False ):
462
+ """ safe get index, translate keys for datelike to underlying repr """
463
+ return self ._get_indices ([name ], raise_on_missing )[0 ]
452
464
453
465
@property
454
466
def name (self ):
@@ -494,7 +506,7 @@ def _set_result_index_ordered(self, result):
494
506
495
507
# shortcut of we have an already ordered grouper
496
508
if not self .grouper .is_monotonic :
497
- index = Index (np .concatenate ([ indices . get ( v , []) for v in self .grouper .result_index ] ))
509
+ index = Index (np .concatenate (self . _get_indices ( self .grouper .result_index ) ))
498
510
result .index = index
499
511
result = result .sort_index ()
500
512
@@ -598,7 +610,7 @@ def get_group(self, name, obj=None):
598
610
if obj is None :
599
611
obj = self ._selected_obj
600
612
601
- inds = self ._get_index (name )
613
+ inds = self ._get_index (name , raise_on_missing = True )
602
614
return obj .take (inds , axis = self .axis , convert = False )
603
615
604
616
def __iter__ (self ):
@@ -2445,9 +2457,6 @@ def transform(self, func, *args, **kwargs):
2445
2457
2446
2458
wrapper = lambda x : func (x , * args , ** kwargs )
2447
2459
for i , (name , group ) in enumerate (self ):
2448
- if name not in self .indices :
2449
- continue
2450
-
2451
2460
object .__setattr__ (group , 'name' , name )
2452
2461
res = wrapper (group )
2453
2462
@@ -2462,7 +2471,7 @@ def transform(self, func, *args, **kwargs):
2462
2471
except :
2463
2472
pass
2464
2473
2465
- indexer = self .indices [ name ]
2474
+ indexer = self ._get_index ( name )
2466
2475
result [indexer ] = res
2467
2476
2468
2477
result = _possibly_downcast_to_dtype (result , dtype )
@@ -2516,11 +2525,8 @@ def true_and_notnull(x, *args, **kwargs):
2516
2525
return b and notnull (b )
2517
2526
2518
2527
try :
2519
- indices = []
2520
- for name , group in self :
2521
- if true_and_notnull (group ) and name in self .indices :
2522
- indices .append (self .indices [name ])
2523
-
2528
+ indices = [self ._get_index (name ) if true_and_notnull (group )
2529
+ for name , group in self ]
2524
2530
except ValueError :
2525
2531
raise TypeError ("the filter must return a boolean result" )
2526
2532
except TypeError :
@@ -3040,8 +3046,8 @@ def transform(self, func, *args, **kwargs):
3040
3046
results = np .empty_like (obj .values , result .values .dtype )
3041
3047
indices = self .indices
3042
3048
for (name , group ), (i , row ) in zip (self , result .iterrows ()):
3043
- if name in indices :
3044
- indexer = indices [ name ]
3049
+ indexer = self . _get_index ( name )
3050
+ if len ( indexer ) > 0 :
3045
3051
results [indexer ] = np .tile (row .values ,len (indexer )).reshape (len (indexer ),- 1 )
3046
3052
3047
3053
counts = self .size ().fillna (0 ).values
@@ -3141,8 +3147,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
3141
3147
3142
3148
# interpret the result of the filter
3143
3149
if is_bool (res ) or (lib .isscalar (res ) and isnull (res )):
3144
- if res and notnull (res ) and name in self . indices :
3145
- indices .append (self .indices [ name ] )
3150
+ if res and notnull (res ):
3151
+ indices .append (self ._get_index ( name ) )
3146
3152
else :
3147
3153
# non scalars aren't allowed
3148
3154
raise TypeError ("filter function returned a %s, "
0 commit comments