@@ -413,46 +413,55 @@ def indices(self):
413
413
""" dict {group name -> group indices} """
414
414
return self .grouper .indices
415
415
416
- def _get_index (self , name ):
417
- """ safe get index , translate keys for datelike to underlying repr """
416
+ def _get_indices (self , names ):
417
+ """ safe get multiple indices , translate keys for datelike to underlying repr """
418
418
419
- def convert ( key , s ):
419
+ def get_converter ( s ):
420
420
# possibly convert to they actual key types
421
421
# in the indices, could be a Timestamp or a np.datetime64
422
-
423
422
if isinstance (s , (Timestamp ,datetime .datetime )):
424
- return Timestamp (key )
423
+ return lambda key : Timestamp (key )
425
424
elif isinstance (s , np .datetime64 ):
426
- return Timestamp (key ).asm8
427
- return key
425
+ return lambda key : Timestamp (key ).asm8
426
+ else :
427
+ return lambda key : key
428
+
429
+ if len (names ) == 0 :
430
+ return []
428
431
429
432
if len (self .indices ) > 0 :
430
- sample = next (iter (self .indices ))
433
+ index_sample = next (iter (self .indices ))
431
434
else :
432
- sample = None # Dummy sample
435
+ index_sample = None # Dummy sample
433
436
434
- if isinstance (sample , tuple ):
435
- if not isinstance (name , tuple ):
437
+ name_sample = names [0 ]
438
+ if isinstance (index_sample , tuple ):
439
+ if not isinstance (name_sample , tuple ):
436
440
msg = ("must supply a tuple to get_group with multiple"
437
441
" grouping keys" )
438
442
raise ValueError (msg )
439
- if not len (name ) == len (sample ):
443
+ if not len (name_sample ) == len (index_sample ):
440
444
try :
441
445
# If the original grouper was a tuple
442
- return self .indices [name ]
446
+ return [ self .indices [name ] for name in names ]
443
447
except KeyError :
444
448
# turns out it wasn't a tuple
445
449
msg = ("must supply a a same-length tuple to get_group"
446
450
" with multiple grouping keys" )
447
451
raise ValueError (msg )
448
452
449
- name = tuple ([ convert (n , k ) for n , k in zip (name ,sample ) ])
453
+ converters = [get_converter (s ) for s in index_sample ]
454
+ names = [tuple ([f (n ) for f , n in zip (converters , name )]) for name in names ]
450
455
451
456
else :
457
+ converter = get_converter (index_sample )
458
+ names = [converter (name ) for name in names ]
452
459
453
- name = convert ( name , sample )
460
+ return [ self . indices . get ( name , []) for name in names ]
454
461
455
- return self .indices [name ]
462
+ def _get_index (self , name ):
463
+ """ safe get index, translate keys for datelike to underlying repr """
464
+ return self ._get_indices ([name ])[0 ]
456
465
457
466
@property
458
467
def name (self ):
@@ -498,7 +507,7 @@ def _set_result_index_ordered(self, result):
498
507
499
508
# shortcut of we have an already ordered grouper
500
509
if not self .grouper .is_monotonic :
501
- index = Index (np .concatenate ([ indices . get ( v , []) for v in self .grouper .result_index ] ))
510
+ index = Index (np .concatenate (self . _get_indices ( self .grouper .result_index ) ))
502
511
result .index = index
503
512
result = result .sort_index ()
504
513
@@ -603,6 +612,9 @@ def get_group(self, name, obj=None):
603
612
obj = self ._selected_obj
604
613
605
614
inds = self ._get_index (name )
615
+ if not len (inds ):
616
+ raise KeyError (name )
617
+
606
618
return obj .take (inds , axis = self .axis , convert = False )
607
619
608
620
def __iter__ (self ):
@@ -2449,9 +2461,6 @@ def transform(self, func, *args, **kwargs):
2449
2461
2450
2462
wrapper = lambda x : func (x , * args , ** kwargs )
2451
2463
for i , (name , group ) in enumerate (self ):
2452
- if name not in self .indices :
2453
- continue
2454
-
2455
2464
object .__setattr__ (group , 'name' , name )
2456
2465
res = wrapper (group )
2457
2466
@@ -2466,7 +2475,7 @@ def transform(self, func, *args, **kwargs):
2466
2475
except :
2467
2476
pass
2468
2477
2469
- indexer = self .indices [ name ]
2478
+ indexer = self ._get_index ( name )
2470
2479
result [indexer ] = res
2471
2480
2472
2481
result = _possibly_downcast_to_dtype (result , dtype )
@@ -2520,11 +2529,8 @@ def true_and_notnull(x, *args, **kwargs):
2520
2529
return b and notnull (b )
2521
2530
2522
2531
try :
2523
- indices = []
2524
- for name , group in self :
2525
- if true_and_notnull (group ) and name in self .indices :
2526
- indices .append (self .indices [name ])
2527
-
2532
+ indices = [self ._get_index (name ) for name , group in self
2533
+ if true_and_notnull (group )]
2528
2534
except ValueError :
2529
2535
raise TypeError ("the filter must return a boolean result" )
2530
2536
except TypeError :
@@ -3044,8 +3050,8 @@ def transform(self, func, *args, **kwargs):
3044
3050
results = np .empty_like (obj .values , result .values .dtype )
3045
3051
indices = self .indices
3046
3052
for (name , group ), (i , row ) in zip (self , result .iterrows ()):
3047
- if name in indices :
3048
- indexer = indices [ name ]
3053
+ indexer = self . _get_index ( name )
3054
+ if len ( indexer ) > 0 :
3049
3055
results [indexer ] = np .tile (row .values ,len (indexer )).reshape (len (indexer ),- 1 )
3050
3056
3051
3057
counts = self .size ().fillna (0 ).values
@@ -3145,8 +3151,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
3145
3151
3146
3152
# interpret the result of the filter
3147
3153
if is_bool (res ) or (lib .isscalar (res ) and isnull (res )):
3148
- if res and notnull (res ) and name in self . indices :
3149
- indices .append (self .indices [ name ] )
3154
+ if res and notnull (res ):
3155
+ indices .append (self ._get_index ( name ) )
3150
3156
else :
3151
3157
# non scalars aren't allowed
3152
3158
raise TypeError ("filter function returned a %s, "
0 commit comments