Skip to content

Commit 441adbc

Browse files
Evan Wrightevanpw
Evan Wright
authored andcommitted
BUG: Filter/transform fail in some cases when multi-grouping with a datetime-like key (GH pandas-dev#10114)
1 parent f6c7d89 commit 441adbc

File tree

3 files changed

+63
-30
lines changed

3 files changed

+63
-30
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ Bug Fixes
6262

6363
- Bug in ``Categorical`` repr with ``display.width`` of ``None`` in Python 3 (:issue:`10087`)
6464

65+
- Bug in ``filter`` (regression from 0.16.0) and ``transform`` when grouping on multiple keys, one of which is datetime-like (:issue:`10114`)
6566

6667
- Bug in ``Timestamp``'s' ``microsecond``, ``quarter``, ``dayofyear``, ``week`` and ``daysinmonth`` properties return ``np.int`` type, not built-in ``int``. (:issue:`10050`)
6768
- Bug in ``NaT`` raises ``AttributeError`` when accessing to ``daysinmonth``, ``dayofweek`` properties. (:issue:`10096`)

pandas/core/groupby.py

+36-30
Original file line numberDiff line numberDiff line change
@@ -413,42 +413,54 @@ def indices(self):
413413
""" dict {group name -> group indices} """
414414
return self.grouper.indices
415415

416-
def _get_index(self, name):
417-
""" safe get index, translate keys for datelike to underlying repr """
416+
def _get_indices(self, names, raise_on_missing=False):
417+
""" safe get multiple indices, translate keys for datelike to underlying repr """
418418

419-
def convert(key, s):
419+
def get_converter(s):
420420
# possibly convert to they actual key types
421421
# in the indices, could be a Timestamp or a np.datetime64
422-
423422
if isinstance(s, (Timestamp,datetime.datetime)):
424-
return Timestamp(key)
423+
return lambda key: Timestamp(key)
425424
elif isinstance(s, np.datetime64):
426-
return Timestamp(key).asm8
427-
return key
425+
return lambda key: Timestamp(key).asm8
426+
else:
427+
return lambda key: key
428428

429-
sample = next(iter(self.indices))
430-
if isinstance(sample, tuple):
431-
if not isinstance(name, tuple):
429+
if len(names) == 0:
430+
return []
431+
432+
index_sample = next(iter(self.indices))
433+
name_sample = names[0]
434+
if isinstance(index_sample, tuple):
435+
if not isinstance(name_sample, tuple):
432436
msg = ("must supply a tuple to get_group with multiple"
433437
" grouping keys")
434438
raise ValueError(msg)
435-
if not len(name) == len(sample):
439+
if not len(name_sample) == len(index_sample):
436440
try:
437441
# If the original grouper was a tuple
438-
return self.indices[name]
442+
return [self.indices[name] for name in names]
439443
except KeyError:
440444
# turns out it wasn't a tuple
441445
msg = ("must supply a a same-length tuple to get_group"
442446
" with multiple grouping keys")
443447
raise ValueError(msg)
444448

445-
name = tuple([ convert(n, k) for n, k in zip(name,sample) ])
449+
converters = [get_converter(s) for s in index_sample]
450+
names = [tuple([f(n) for f, n in zip(converters, name)]) for name in names]
446451

447452
else:
453+
converter = get_converter(index_sample)
454+
names = [converter(name) for name in names]
448455

449-
name = convert(name, sample)
456+
if raise_on_missing:
457+
return [self.indices[name] for name in names]
458+
else:
459+
return [self.indices.get(name, []) for name in names]
450460

451-
return self.indices[name]
461+
def _get_index(self, name, raise_on_missing=False):
462+
""" safe get index, translate keys for datelike to underlying repr """
463+
return self._get_indices([name], raise_on_missing)[0]
452464

453465
@property
454466
def name(self):
@@ -494,7 +506,7 @@ def _set_result_index_ordered(self, result):
494506

495507
# shortcut of we have an already ordered grouper
496508
if not self.grouper.is_monotonic:
497-
index = Index(np.concatenate([ indices.get(v, []) for v in self.grouper.result_index]))
509+
index = Index(np.concatenate(self._get_indices(self.grouper.result_index)))
498510
result.index = index
499511
result = result.sort_index()
500512

@@ -598,7 +610,7 @@ def get_group(self, name, obj=None):
598610
if obj is None:
599611
obj = self._selected_obj
600612

601-
inds = self._get_index(name)
613+
inds = self._get_index(name, raise_on_missing=True)
602614
return obj.take(inds, axis=self.axis, convert=False)
603615

604616
def __iter__(self):
@@ -2445,9 +2457,6 @@ def transform(self, func, *args, **kwargs):
24452457

24462458
wrapper = lambda x: func(x, *args, **kwargs)
24472459
for i, (name, group) in enumerate(self):
2448-
if name not in self.indices:
2449-
continue
2450-
24512460
object.__setattr__(group, 'name', name)
24522461
res = wrapper(group)
24532462

@@ -2462,7 +2471,7 @@ def transform(self, func, *args, **kwargs):
24622471
except:
24632472
pass
24642473

2465-
indexer = self.indices[name]
2474+
indexer = self._get_index(name)
24662475
result[indexer] = res
24672476

24682477
result = _possibly_downcast_to_dtype(result, dtype)
@@ -2516,11 +2525,8 @@ def true_and_notnull(x, *args, **kwargs):
25162525
return b and notnull(b)
25172526

25182527
try:
2519-
indices = []
2520-
for name, group in self:
2521-
if true_and_notnull(group) and name in self.indices:
2522-
indices.append(self.indices[name])
2523-
2528+
indices = [self._get_index(name) if true_and_notnull(group)
2529+
for name, group in self]
25242530
except ValueError:
25252531
raise TypeError("the filter must return a boolean result")
25262532
except TypeError:
@@ -3040,8 +3046,8 @@ def transform(self, func, *args, **kwargs):
30403046
results = np.empty_like(obj.values, result.values.dtype)
30413047
indices = self.indices
30423048
for (name, group), (i, row) in zip(self, result.iterrows()):
3043-
if name in indices:
3044-
indexer = indices[name]
3049+
indexer = self._get_index(name)
3050+
if len(indexer) > 0:
30453051
results[indexer] = np.tile(row.values,len(indexer)).reshape(len(indexer),-1)
30463052

30473053
counts = self.size().fillna(0).values
@@ -3141,8 +3147,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
31413147

31423148
# interpret the result of the filter
31433149
if is_bool(res) or (lib.isscalar(res) and isnull(res)):
3144-
if res and notnull(res) and name in self.indices:
3145-
indices.append(self.indices[name])
3150+
if res and notnull(res):
3151+
indices.append(self._get_index(name))
31463152
else:
31473153
# non scalars aren't allowed
31483154
raise TypeError("filter function returned a %s, "

pandas/tests/test_groupby.py

+26
Original file line numberDiff line numberDiff line change
@@ -4335,6 +4335,32 @@ def test_filter_maintains_ordering(self):
43354335
expected = s.iloc[[1, 2, 4, 7]]
43364336
assert_series_equal(actual, expected)
43374337

4338+
def test_filter_multiple_timestamp(self):
4339+
# GH 10114
4340+
df = DataFrame({'A' : np.arange(5),
4341+
'B' : ['foo','bar','foo','bar','bar'],
4342+
'C' : Timestamp('20130101') })
4343+
4344+
grouped = df.groupby(['B', 'C'])
4345+
4346+
result = grouped['A'].filter(lambda x: True)
4347+
assert_series_equal(df['A'], result)
4348+
4349+
result = grouped['A'].transform(len)
4350+
expected = Series([2, 3, 2, 3, 3], name='A')
4351+
assert_series_equal(result, expected)
4352+
4353+
result = grouped.filter(lambda x: True)
4354+
assert_frame_equal(df, result)
4355+
4356+
result = grouped.transform('sum')
4357+
expected = DataFrame({'A' : [2, 8, 2, 8, 8]})
4358+
assert_frame_equal(result, expected)
4359+
4360+
result = grouped.transform(len)
4361+
expected = DataFrame({'A' : [2, 3, 2, 3, 3]})
4362+
assert_frame_equal(result, expected)
4363+
43384364
def test_filter_and_transform_with_non_unique_int_index(self):
43394365
# GH4620
43404366
index = [1, 1, 1, 2, 1, 1, 0, 1]

0 commit comments

Comments
 (0)