Skip to content

Commit c84ab54

Browse files
author
Evan Wright
committed
BUG: Filter/transform fail in some cases when multi-grouping with a datetime-like key (GH #10114)
1 parent 92da9ed commit c84ab54

File tree

3 files changed

+63
-31
lines changed

3 files changed

+63
-31
lines changed

doc/source/whatsnew/v0.17.0.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ Bug Fixes
390390
- Bug in ``pd.get_dummies`` with ``sparse=True`` not returning ``SparseDataFrame`` (:issue:`10531`)
391391
- Bug in ``Index`` subtypes (such as ``PeriodIndex``) not returning their own type for ``.drop`` and ``.insert`` methods (:issue:`10620`)
392392

393-
393+
- Bug in ``filter`` (regression from 0.16.0) and ``transform`` when grouping on multiple keys, one of which is datetime-like (:issue:`10114`)
394394

395395

396396

pandas/core/groupby.py

+36-30
Original file line numberDiff line numberDiff line change
@@ -422,46 +422,55 @@ def indices(self):
422422
""" dict {group name -> group indices} """
423423
return self.grouper.indices
424424

425-
def _get_index(self, name):
426-
""" safe get index, translate keys for datelike to underlying repr """
425+
def _get_indices(self, names):
426+
""" safe get multiple indices, translate keys for datelike to underlying repr """
427427

428-
def convert(key, s):
428+
def get_converter(s):
429429
# possibly convert to the actual key types
430430
# in the indices, could be a Timestamp or a np.datetime64
431-
432431
if isinstance(s, (Timestamp,datetime.datetime)):
433-
return Timestamp(key)
432+
return lambda key: Timestamp(key)
434433
elif isinstance(s, np.datetime64):
435-
return Timestamp(key).asm8
436-
return key
434+
return lambda key: Timestamp(key).asm8
435+
else:
436+
return lambda key: key
437+
438+
if len(names) == 0:
439+
return []
437440

438441
if len(self.indices) > 0:
439-
sample = next(iter(self.indices))
442+
index_sample = next(iter(self.indices))
440443
else:
441-
sample = None # Dummy sample
444+
index_sample = None # Dummy sample
442445

443-
if isinstance(sample, tuple):
444-
if not isinstance(name, tuple):
446+
name_sample = names[0]
447+
if isinstance(index_sample, tuple):
448+
if not isinstance(name_sample, tuple):
445449
msg = ("must supply a tuple to get_group with multiple"
446450
" grouping keys")
447451
raise ValueError(msg)
448-
if not len(name) == len(sample):
452+
if not len(name_sample) == len(index_sample):
449453
try:
450454
# If the original grouper was a tuple
451-
return self.indices[name]
455+
return [self.indices[name] for name in names]
452456
except KeyError:
453457
# turns out it wasn't a tuple
454458
msg = ("must supply a a same-length tuple to get_group"
455459
" with multiple grouping keys")
456460
raise ValueError(msg)
457461

458-
name = tuple([ convert(n, k) for n, k in zip(name,sample) ])
462+
converters = [get_converter(s) for s in index_sample]
463+
names = [tuple([f(n) for f, n in zip(converters, name)]) for name in names]
459464

460465
else:
466+
converter = get_converter(index_sample)
467+
names = [converter(name) for name in names]
461468

462-
name = convert(name, sample)
469+
return [self.indices.get(name, []) for name in names]
463470

464-
return self.indices[name]
471+
def _get_index(self, name):
472+
""" safe get index, translate keys for datelike to underlying repr """
473+
return self._get_indices([name])[0]
465474

466475
@property
467476
def name(self):
@@ -507,7 +516,7 @@ def _set_result_index_ordered(self, result):
507516

508517
# shortcut if we have an already ordered grouper
509518
if not self.grouper.is_monotonic:
510-
index = Index(np.concatenate([ indices.get(v, []) for v in self.grouper.result_index]))
519+
index = Index(np.concatenate(self._get_indices(self.grouper.result_index)))
511520
result.index = index
512521
result = result.sort_index()
513522

@@ -612,6 +621,9 @@ def get_group(self, name, obj=None):
612621
obj = self._selected_obj
613622

614623
inds = self._get_index(name)
624+
if not len(inds):
625+
raise KeyError(name)
626+
615627
return obj.take(inds, axis=self.axis, convert=False)
616628

617629
def __iter__(self):
@@ -2457,9 +2469,6 @@ def transform(self, func, *args, **kwargs):
24572469

24582470
wrapper = lambda x: func(x, *args, **kwargs)
24592471
for i, (name, group) in enumerate(self):
2460-
if name not in self.indices:
2461-
continue
2462-
24632472
object.__setattr__(group, 'name', name)
24642473
res = wrapper(group)
24652474

@@ -2474,7 +2483,7 @@ def transform(self, func, *args, **kwargs):
24742483
except:
24752484
pass
24762485

2477-
indexer = self.indices[name]
2486+
indexer = self._get_index(name)
24782487
result[indexer] = res
24792488

24802489
result = _possibly_downcast_to_dtype(result, dtype)
@@ -2528,11 +2537,8 @@ def true_and_notnull(x, *args, **kwargs):
25282537
return b and notnull(b)
25292538

25302539
try:
2531-
indices = []
2532-
for name, group in self:
2533-
if true_and_notnull(group) and name in self.indices:
2534-
indices.append(self.indices[name])
2535-
2540+
indices = [self._get_index(name) for name, group in self
2541+
if true_and_notnull(group)]
25362542
except ValueError:
25372543
raise TypeError("the filter must return a boolean result")
25382544
except TypeError:
@@ -3060,8 +3066,8 @@ def transform(self, func, *args, **kwargs):
30603066
results = np.empty_like(obj.values, result.values.dtype)
30613067
indices = self.indices
30623068
for (name, group), (i, row) in zip(self, result.iterrows()):
3063-
if name in indices:
3064-
indexer = indices[name]
3069+
indexer = self._get_index(name)
3070+
if len(indexer) > 0:
30653071
results[indexer] = np.tile(row.values,len(indexer)).reshape(len(indexer),-1)
30663072

30673073
counts = self.size().fillna(0).values
@@ -3162,8 +3168,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
31623168

31633169
# interpret the result of the filter
31643170
if is_bool(res) or (lib.isscalar(res) and isnull(res)):
3165-
if res and notnull(res) and name in self.indices:
3166-
indices.append(self.indices[name])
3171+
if res and notnull(res):
3172+
indices.append(self._get_index(name))
31673173
else:
31683174
# non scalars aren't allowed
31693175
raise TypeError("filter function returned a %s, "

pandas/tests/test_groupby.py

+26
Original file line numberDiff line numberDiff line change
@@ -4477,6 +4477,32 @@ def test_filter_maintains_ordering(self):
44774477
expected = s.iloc[[1, 2, 4, 7]]
44784478
assert_series_equal(actual, expected)
44794479

4480+
def test_filter_multiple_timestamp(self):
4481+
# GH 10114
4482+
df = DataFrame({'A' : np.arange(5),
4483+
'B' : ['foo','bar','foo','bar','bar'],
4484+
'C' : Timestamp('20130101') })
4485+
4486+
grouped = df.groupby(['B', 'C'])
4487+
4488+
result = grouped['A'].filter(lambda x: True)
4489+
assert_series_equal(df['A'], result)
4490+
4491+
result = grouped['A'].transform(len)
4492+
expected = Series([2, 3, 2, 3, 3], name='A')
4493+
assert_series_equal(result, expected)
4494+
4495+
result = grouped.filter(lambda x: True)
4496+
assert_frame_equal(df, result)
4497+
4498+
result = grouped.transform('sum')
4499+
expected = DataFrame({'A' : [2, 8, 2, 8, 8]})
4500+
assert_frame_equal(result, expected)
4501+
4502+
result = grouped.transform(len)
4503+
expected = DataFrame({'A' : [2, 3, 2, 3, 3]})
4504+
assert_frame_equal(result, expected)
4505+
44804506
def test_filter_and_transform_with_non_unique_int_index(self):
44814507
# GH4620
44824508
index = [1, 1, 1, 2, 1, 1, 0, 1]

0 commit comments

Comments
 (0)