Skip to content

Commit a517cd4

Browse files
Evan Wrightevanpw
Evan Wright
authored andcommitted
BUG: Filter/transform fail in some cases when multi-grouping with a datetime-like key (GH pandas-dev#10114)
1 parent 676cb95 commit a517cd4

File tree

3 files changed

+64
-30
lines changed

3 files changed

+64
-30
lines changed

doc/source/whatsnew/v0.16.2.txt

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ Bug Fixes
6262
- Bug in ``Timestamp``'s' ``microsecond``, ``quarter``, ``dayofyear``, ``week`` and ``daysinmonth`` properties return ``np.int`` type, not built-in ``int``. (:issue:`10050`)
6363
- Bug in ``NaT`` raises ``AttributeError`` when accessing to ``daysinmonth``, ``dayofweek`` properties. (:issue:`10096`)
6464

65+
- Bug in ``filter`` (regression from 0.16.0) and ``transform`` when grouping on multiple keys, one of which is datetime-like (:issue:`10114`)
66+
6567

6668
- Bug in getting timezone data with ``dateutil`` on various platforms ( :issue:`9059`, :issue:`8639`, :issue:`9663`, :issue:`10121`)
6769
- Bug in display datetimes with mixed frequencies uniformly; display 'ms' datetimes to the proper precision. (:issue:`10170`)

pandas/core/groupby.py

100644100755
+36-30
Original file line numberDiff line numberDiff line change
@@ -413,46 +413,55 @@ def indices(self):
413413
""" dict {group name -> group indices} """
414414
return self.grouper.indices
415415

416-
def _get_index(self, name):
417-
""" safe get index, translate keys for datelike to underlying repr """
416+
def _get_indices(self, names):
417+
""" safe get multiple indices, translate keys for datelike to underlying repr """
418418

419-
def convert(key, s):
419+
def get_converter(s):
420420
# possibly convert to they actual key types
421421
# in the indices, could be a Timestamp or a np.datetime64
422-
423422
if isinstance(s, (Timestamp,datetime.datetime)):
424-
return Timestamp(key)
423+
return lambda key: Timestamp(key)
425424
elif isinstance(s, np.datetime64):
426-
return Timestamp(key).asm8
427-
return key
425+
return lambda key: Timestamp(key).asm8
426+
else:
427+
return lambda key: key
428+
429+
if len(names) == 0:
430+
return []
428431

429432
if len(self.indices) > 0:
430-
sample = next(iter(self.indices))
433+
index_sample = next(iter(self.indices))
431434
else:
432-
sample = None # Dummy sample
435+
index_sample = None # Dummy sample
433436

434-
if isinstance(sample, tuple):
435-
if not isinstance(name, tuple):
437+
name_sample = names[0]
438+
if isinstance(index_sample, tuple):
439+
if not isinstance(name_sample, tuple):
436440
msg = ("must supply a tuple to get_group with multiple"
437441
" grouping keys")
438442
raise ValueError(msg)
439-
if not len(name) == len(sample):
443+
if not len(name_sample) == len(index_sample):
440444
try:
441445
# If the original grouper was a tuple
442-
return self.indices[name]
446+
return [self.indices[name] for name in names]
443447
except KeyError:
444448
# turns out it wasn't a tuple
445449
msg = ("must supply a a same-length tuple to get_group"
446450
" with multiple grouping keys")
447451
raise ValueError(msg)
448452

449-
name = tuple([ convert(n, k) for n, k in zip(name,sample) ])
453+
converters = [get_converter(s) for s in index_sample]
454+
names = [tuple([f(n) for f, n in zip(converters, name)]) for name in names]
450455

451456
else:
457+
converter = get_converter(index_sample)
458+
names = [converter(name) for name in names]
452459

453-
name = convert(name, sample)
460+
return [self.indices.get(name, []) for name in names]
454461

455-
return self.indices[name]
462+
def _get_index(self, name):
463+
""" safe get index, translate keys for datelike to underlying repr """
464+
return self._get_indices([name])[0]
456465

457466
@property
458467
def name(self):
@@ -498,7 +507,7 @@ def _set_result_index_ordered(self, result):
498507

499508
# shortcut of we have an already ordered grouper
500509
if not self.grouper.is_monotonic:
501-
index = Index(np.concatenate([ indices.get(v, []) for v in self.grouper.result_index]))
510+
index = Index(np.concatenate(self._get_indices(self.grouper.result_index)))
502511
result.index = index
503512
result = result.sort_index()
504513

@@ -603,6 +612,9 @@ def get_group(self, name, obj=None):
603612
obj = self._selected_obj
604613

605614
inds = self._get_index(name)
615+
if not len(inds):
616+
raise KeyError(name)
617+
606618
return obj.take(inds, axis=self.axis, convert=False)
607619

608620
def __iter__(self):
@@ -2449,9 +2461,6 @@ def transform(self, func, *args, **kwargs):
24492461

24502462
wrapper = lambda x: func(x, *args, **kwargs)
24512463
for i, (name, group) in enumerate(self):
2452-
if name not in self.indices:
2453-
continue
2454-
24552464
object.__setattr__(group, 'name', name)
24562465
res = wrapper(group)
24572466

@@ -2466,7 +2475,7 @@ def transform(self, func, *args, **kwargs):
24662475
except:
24672476
pass
24682477

2469-
indexer = self.indices[name]
2478+
indexer = self._get_index(name)
24702479
result[indexer] = res
24712480

24722481
result = _possibly_downcast_to_dtype(result, dtype)
@@ -2520,11 +2529,8 @@ def true_and_notnull(x, *args, **kwargs):
25202529
return b and notnull(b)
25212530

25222531
try:
2523-
indices = []
2524-
for name, group in self:
2525-
if true_and_notnull(group) and name in self.indices:
2526-
indices.append(self.indices[name])
2527-
2532+
indices = [self._get_index(name) for name, group in self
2533+
if true_and_notnull(group)]
25282534
except ValueError:
25292535
raise TypeError("the filter must return a boolean result")
25302536
except TypeError:
@@ -3044,8 +3050,8 @@ def transform(self, func, *args, **kwargs):
30443050
results = np.empty_like(obj.values, result.values.dtype)
30453051
indices = self.indices
30463052
for (name, group), (i, row) in zip(self, result.iterrows()):
3047-
if name in indices:
3048-
indexer = indices[name]
3053+
indexer = self._get_index(name)
3054+
if len(indexer) > 0:
30493055
results[indexer] = np.tile(row.values,len(indexer)).reshape(len(indexer),-1)
30503056

30513057
counts = self.size().fillna(0).values
@@ -3145,8 +3151,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
31453151

31463152
# interpret the result of the filter
31473153
if is_bool(res) or (lib.isscalar(res) and isnull(res)):
3148-
if res and notnull(res) and name in self.indices:
3149-
indices.append(self.indices[name])
3154+
if res and notnull(res):
3155+
indices.append(self._get_index(name))
31503156
else:
31513157
# non scalars aren't allowed
31523158
raise TypeError("filter function returned a %s, "

pandas/tests/test_groupby.py

+26
Original file line numberDiff line numberDiff line change
@@ -4377,6 +4377,32 @@ def test_filter_maintains_ordering(self):
43774377
expected = s.iloc[[1, 2, 4, 7]]
43784378
assert_series_equal(actual, expected)
43794379

4380+
def test_filter_multiple_timestamp(self):
4381+
# GH 10114
4382+
df = DataFrame({'A' : np.arange(5),
4383+
'B' : ['foo','bar','foo','bar','bar'],
4384+
'C' : Timestamp('20130101') })
4385+
4386+
grouped = df.groupby(['B', 'C'])
4387+
4388+
result = grouped['A'].filter(lambda x: True)
4389+
assert_series_equal(df['A'], result)
4390+
4391+
result = grouped['A'].transform(len)
4392+
expected = Series([2, 3, 2, 3, 3], name='A')
4393+
assert_series_equal(result, expected)
4394+
4395+
result = grouped.filter(lambda x: True)
4396+
assert_frame_equal(df, result)
4397+
4398+
result = grouped.transform('sum')
4399+
expected = DataFrame({'A' : [2, 8, 2, 8, 8]})
4400+
assert_frame_equal(result, expected)
4401+
4402+
result = grouped.transform(len)
4403+
expected = DataFrame({'A' : [2, 3, 2, 3, 3]})
4404+
assert_frame_equal(result, expected)
4405+
43804406
def test_filter_and_transform_with_non_unique_int_index(self):
43814407
# GH4620
43824408
index = [1, 1, 1, 2, 1, 1, 0, 1]

0 commit comments

Comments
 (0)