Skip to content

Commit 5efb281

Browse files
committed
BUG: transform and filter misbehave when grouping on categorical data (GH 9921)
1 parent e6c4f76 commit 5efb281

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

doc/source/whatsnew/v0.16.1.txt

+2
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,5 @@ Bug Fixes
244244

245245

246246
- Bug in hiding ticklabels with subplots and shared axes when adding a new plot to an existing grid of axes (:issue:`9158`)
247+
- Bug in ``transform`` and ``filter`` when grouping on a categorical variable (:issue:`9921`)
248+

pandas/core/groupby.py

+24-22
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def _set_result_index_ordered(self, result):
491491

492492
# shortcut if we have an already ordered grouper
493493
if not self.grouper.is_monotonic:
494-
index = Index(np.concatenate([ indices[v] for v in self.grouper.result_index ]))
494+
index = Index(np.concatenate([ indices[v] for v in self.grouper.result_index if v in indices]))
495495
result.index = index
496496
result = result.sort_index()
497497

@@ -2436,6 +2436,8 @@ def transform(self, func, *args, **kwargs):
24362436

24372437
wrapper = lambda x: func(x, *args, **kwargs)
24382438
for i, (name, group) in enumerate(self):
2439+
if name not in self.indices:
2440+
continue
24392441

24402442
object.__setattr__(group, 'name', name)
24412443
res = wrapper(group)
@@ -2451,7 +2453,7 @@ def transform(self, func, *args, **kwargs):
24512453
except:
24522454
pass
24532455

2454-
indexer = self._get_index(name)
2456+
indexer = self.indices[name]
24552457
result[indexer] = res
24562458

24572459
result = _possibly_downcast_to_dtype(result, dtype)
@@ -2465,9 +2467,12 @@ def _transform_fast(self, func):
24652467
"""
24662468
if isinstance(func, compat.string_types):
24672469
func = getattr(self,func)
2470+
24682471
values = func().values
2469-
counts = self.size().values
2472+
counts = self.size().fillna(0).values
24702473
values = np.repeat(values, com._ensure_platform_int(counts))
2474+
if any(counts == 0):
2475+
values = self._try_cast(values, self._selected_obj)
24712476

24722477
return self._set_result_index_ordered(Series(values))
24732478

@@ -2502,8 +2507,11 @@ def true_and_notnull(x, *args, **kwargs):
25022507
return b and notnull(b)
25032508

25042509
try:
2505-
indices = [self._get_index(name) if true_and_notnull(group) else []
2506-
for name, group in self]
2510+
indices = []
2511+
for name, group in self:
2512+
if true_and_notnull(group) and name in self.indices:
2513+
indices.append(self.indices[name])
2514+
25072515
except ValueError:
25082516
raise TypeError("the filter must return a boolean result")
25092517
except TypeError:
@@ -3015,24 +3023,18 @@ def transform(self, func, *args, **kwargs):
30153023
if not result.columns.equals(obj.columns):
30163024
return self._transform_general(func, *args, **kwargs)
30173025

3018-
# a grouped that doesn't preserve the index, remap index based on the grouper
3019-
# and broadcast it
3020-
if ((not isinstance(obj.index,MultiIndex) and
3021-
type(result.index) != type(obj.index)) or
3022-
len(result.index) != len(obj.index)):
3023-
results = np.empty_like(obj.values, result.values.dtype)
3024-
indices = self.indices
3025-
for (name, group), (i, row) in zip(self, result.iterrows()):
3026+
results = np.empty_like(obj.values, result.values.dtype)
3027+
indices = self.indices
3028+
for (name, group), (i, row) in zip(self, result.iterrows()):
3029+
if name in indices:
30263030
indexer = indices[name]
30273031
results[indexer] = np.tile(row.values,len(indexer)).reshape(len(indexer),-1)
3028-
return DataFrame(results,columns=result.columns,index=obj.index).convert_objects()
30293032

3030-
# we can merge the result in
3031-
# GH 7383
3032-
names = result.columns
3033-
result = obj.merge(result, how='outer', left_index=True, right_index=True).iloc[:,-result.shape[1]:]
3034-
result.columns = names
3035-
return result
3033+
counts = self.size().fillna(0).values
3034+
if any(counts == 0):
3035+
results = self._try_cast(results, obj[result.columns])
3036+
3037+
return DataFrame(results,columns=result.columns,index=obj.index).convert_objects()
30363038

30373039
def _define_paths(self, func, *args, **kwargs):
30383040
if isinstance(func, compat.string_types):
@@ -3126,8 +3128,8 @@ def filter(self, func, dropna=True, *args, **kwargs):
31263128
# interpret the result of the filter
31273129
if (isinstance(res, (bool, np.bool_)) or
31283130
np.isscalar(res) and isnull(res)):
3129-
if res and notnull(res):
3130-
indices.append(self._get_index(name))
3131+
if res and notnull(res) and name in self.indices:
3132+
indices.append(self.indices[name])
31313133
else:
31323134
# non scalars aren't allowed
31333135
raise TypeError("filter function returned a %s, "

pandas/tests/test_categorical.py

+21
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,27 @@ def f(x):
18201820
expected['person_name'] = expected['person_name'].astype('object')
18211821
tm.assert_frame_equal(result, expected)
18221822

1823+
# GH 9921
1824+
# Monotonic
1825+
df = DataFrame({"a": [5, 15, 25]})
1826+
c = pd.cut(df.a, bins=[0,10,20,30,40])
1827+
tm.assert_series_equal(df.a.groupby(c).transform(sum), df['a'])
1828+
tm.assert_series_equal(df.a.groupby(c).transform(lambda xs: np.sum(xs)), df['a'])
1829+
tm.assert_frame_equal(df.groupby(c).transform(sum), df[['a']])
1830+
tm.assert_frame_equal(df.groupby(c).transform(lambda xs: np.max(xs)), df[['a']])
1831+
1832+
# Filter
1833+
tm.assert_series_equal(df.a.groupby(c).filter(np.all), df['a'])
1834+
tm.assert_frame_equal(df.groupby(c).filter(np.all), df)
1835+
1836+
# Non-monotonic
1837+
df = DataFrame({"a": [5, 15, 25, -5]})
1838+
c = pd.cut(df.a, bins=[-10, 0,10,20,30,40])
1839+
tm.assert_series_equal(df.a.groupby(c).transform(sum), df['a'])
1840+
tm.assert_series_equal(df.a.groupby(c).transform(lambda xs: np.sum(xs)), df['a'])
1841+
tm.assert_frame_equal(df.groupby(c).transform(sum), df[['a']])
1842+
tm.assert_frame_equal(df.groupby(c).transform(lambda xs: np.sum(xs)), df[['a']])
1843+
18231844
def test_pivot_table(self):
18241845

18251846
raw_cat1 = Categorical(["a","a","b","b"], categories=["a","b","z"], ordered=True)

0 commit comments

Comments
 (0)