@@ -422,46 +422,55 @@ def indices(self):
422
422
""" dict {group name -> group indices} """
423
423
return self .grouper .indices
424
424
425
- def _get_index (self , name ):
426
- """ safe get index , translate keys for datelike to underlying repr """
425
+ def _get_indices (self , names ):
426
+ """ safe get multiple indices , translate keys for datelike to underlying repr """
427
427
428
- def convert ( key , s ):
428
+ def get_converter ( s ):
429
429
# possibly convert to they actual key types
430
430
# in the indices, could be a Timestamp or a np.datetime64
431
-
432
431
if isinstance (s , (Timestamp ,datetime .datetime )):
433
- return Timestamp (key )
432
+ return lambda key : Timestamp (key )
434
433
elif isinstance (s , np .datetime64 ):
435
- return Timestamp (key ).asm8
436
- return key
434
+ return lambda key : Timestamp (key ).asm8
435
+ else :
436
+ return lambda key : key
437
+
438
+ if len (names ) == 0 :
439
+ return []
437
440
438
441
if len (self .indices ) > 0 :
439
- sample = next (iter (self .indices ))
442
+ index_sample = next (iter (self .indices ))
440
443
else :
441
- sample = None # Dummy sample
444
+ index_sample = None # Dummy sample
442
445
443
- if isinstance (sample , tuple ):
444
- if not isinstance (name , tuple ):
446
+ name_sample = names [0 ]
447
+ if isinstance (index_sample , tuple ):
448
+ if not isinstance (name_sample , tuple ):
445
449
msg = ("must supply a tuple to get_group with multiple"
446
450
" grouping keys" )
447
451
raise ValueError (msg )
448
- if not len (name ) == len (sample ):
452
+ if not len (name_sample ) == len (index_sample ):
449
453
try :
450
454
# If the original grouper was a tuple
451
- return self .indices [name ]
455
+ return [ self .indices [name ] for name in names ]
452
456
except KeyError :
453
457
# turns out it wasn't a tuple
454
458
msg = ("must supply a a same-length tuple to get_group"
455
459
" with multiple grouping keys" )
456
460
raise ValueError (msg )
457
461
458
- name = tuple ([ convert (n , k ) for n , k in zip (name ,sample ) ])
462
+ converters = [get_converter (s ) for s in index_sample ]
463
+ names = [tuple ([f (n ) for f , n in zip (converters , name )]) for name in names ]
459
464
460
465
else :
466
+ converter = get_converter (index_sample )
467
+ names = [converter (name ) for name in names ]
461
468
462
- name = convert ( name , sample )
469
+ return [ self . indices . get ( name , []) for name in names ]
463
470
464
- return self .indices [name ]
471
+ def _get_index (self , name ):
472
+ """ safe get index, translate keys for datelike to underlying repr """
473
+ return self ._get_indices ([name ])[0 ]
465
474
466
475
@property
467
476
def name (self ):
@@ -507,7 +516,7 @@ def _set_result_index_ordered(self, result):
507
516
508
517
# shortcut of we have an already ordered grouper
509
518
if not self .grouper .is_monotonic :
510
- index = Index (np .concatenate ([ indices . get ( v , []) for v in self .grouper .result_index ] ))
519
+ index = Index (np .concatenate (self . _get_indices ( self .grouper .result_index ) ))
511
520
result .index = index
512
521
result = result .sort_index ()
513
522
@@ -612,6 +621,9 @@ def get_group(self, name, obj=None):
612
621
obj = self ._selected_obj
613
622
614
623
inds = self ._get_index (name )
624
+ if not len (inds ):
625
+ raise KeyError (name )
626
+
615
627
return obj .take (inds , axis = self .axis , convert = False )
616
628
617
629
def __iter__ (self ):
@@ -2457,9 +2469,6 @@ def transform(self, func, *args, **kwargs):
2457
2469
2458
2470
wrapper = lambda x : func (x , * args , ** kwargs )
2459
2471
for i , (name , group ) in enumerate (self ):
2460
- if name not in self .indices :
2461
- continue
2462
-
2463
2472
object .__setattr__ (group , 'name' , name )
2464
2473
res = wrapper (group )
2465
2474
@@ -2474,7 +2483,7 @@ def transform(self, func, *args, **kwargs):
2474
2483
except :
2475
2484
pass
2476
2485
2477
- indexer = self .indices [ name ]
2486
+ indexer = self ._get_index ( name )
2478
2487
result [indexer ] = res
2479
2488
2480
2489
result = _possibly_downcast_to_dtype (result , dtype )
@@ -2528,11 +2537,8 @@ def true_and_notnull(x, *args, **kwargs):
2528
2537
return b and notnull (b )
2529
2538
2530
2539
try :
2531
- indices = []
2532
- for name , group in self :
2533
- if true_and_notnull (group ) and name in self .indices :
2534
- indices .append (self .indices [name ])
2535
-
2540
+ indices = [self ._get_index (name ) for name , group in self
2541
+ if true_and_notnull (group )]
2536
2542
except ValueError :
2537
2543
raise TypeError ("the filter must return a boolean result" )
2538
2544
except TypeError :
@@ -3060,8 +3066,8 @@ def transform(self, func, *args, **kwargs):
3060
3066
results = np .empty_like (obj .values , result .values .dtype )
3061
3067
indices = self .indices
3062
3068
for (name , group ), (i , row ) in zip (self , result .iterrows ()):
3063
- if name in indices :
3064
- indexer = indices [ name ]
3069
+ indexer = self . _get_index ( name )
3070
+ if len ( indexer ) > 0 :
3065
3071
results [indexer ] = np .tile (row .values ,len (indexer )).reshape (len (indexer ),- 1 )
3066
3072
3067
3073
counts = self .size ().fillna (0 ).values
@@ -3162,8 +3168,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
3162
3168
3163
3169
# interpret the result of the filter
3164
3170
if is_bool (res ) or (lib .isscalar (res ) and isnull (res )):
3165
- if res and notnull (res ) and name in self . indices :
3166
- indices .append (self .indices [ name ] )
3171
+ if res and notnull (res ):
3172
+ indices .append (self ._get_index ( name ) )
3167
3173
else :
3168
3174
# non scalars aren't allowed
3169
3175
raise TypeError ("filter function returned a %s, "
0 commit comments