Skip to content

Commit c9d1ef9

Browse files
committed
Merge pull request #9994 from evanpw/issue_9921
BUG: transform and filter misbehave when grouping on categorical data
2 parents a93c547 + 3d73550 commit c9d1ef9

File tree

4 files changed

+56
-26
lines changed

4 files changed

+56
-26
lines changed

doc/source/whatsnew/v0.16.1.txt

+2
Original file line number | Diff line number | Diff line change
@@ -252,3 +252,5 @@ Bug Fixes
252252

253253

254254
- Bug in hiding ticklabels with subplots and shared axes when adding a new plot to an existing grid of axes (:issue:`9158`)
255+
- Bug in ``transform`` and ``filter`` when grouping on a categorical variable (:issue:`9921`)
256+
- Bug in ``transform`` when groups are equal in number and dtype to the input index (:issue:`9700`)

pandas/core/groupby.py

+27-26
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@
2525
notnull, _DATELIKE_DTYPES, is_numeric_dtype,
2626
is_timedelta64_dtype, is_datetime64_dtype,
2727
is_categorical_dtype, _values_from_object,
28-
is_datetime_or_timedelta_dtype, is_bool_dtype,
29-
AbstractMethodError)
28+
is_datetime_or_timedelta_dtype, is_bool,
29+
is_bool_dtype, AbstractMethodError)
3030
from pandas.core.config import option_context
3131
import pandas.lib as lib
3232
from pandas.lib import Timestamp
@@ -491,7 +491,7 @@ def _set_result_index_ordered(self, result):
491491

492492
# shortcut if we have an already ordered grouper
493493
if not self.grouper.is_monotonic:
494-
index = Index(np.concatenate([ indices[v] for v in self.grouper.result_index ]))
494+
index = Index(np.concatenate([ indices.get(v, []) for v in self.grouper.result_index]))
495495
result.index = index
496496
result = result.sort_index()
497497

@@ -2436,6 +2436,8 @@ def transform(self, func, *args, **kwargs):
24362436

24372437
wrapper = lambda x: func(x, *args, **kwargs)
24382438
for i, (name, group) in enumerate(self):
2439+
if name not in self.indices:
2440+
continue
24392441

24402442
object.__setattr__(group, 'name', name)
24412443
res = wrapper(group)
@@ -2451,7 +2453,7 @@ def transform(self, func, *args, **kwargs):
24512453
except:
24522454
pass
24532455

2454-
indexer = self._get_index(name)
2456+
indexer = self.indices[name]
24552457
result[indexer] = res
24562458

24572459
result = _possibly_downcast_to_dtype(result, dtype)
@@ -2465,9 +2467,12 @@ def _transform_fast(self, func):
24652467
"""
24662468
if isinstance(func, compat.string_types):
24672469
func = getattr(self,func)
2470+
24682471
values = func().values
2469-
counts = self.size().values
2472+
counts = self.size().fillna(0).values
24702473
values = np.repeat(values, com._ensure_platform_int(counts))
2474+
if any(counts == 0):
2475+
values = self._try_cast(values, self._selected_obj)
24712476

24722477
return self._set_result_index_ordered(Series(values))
24732478

@@ -2502,8 +2507,11 @@ def true_and_notnull(x, *args, **kwargs):
25022507
return b and notnull(b)
25032508

25042509
try:
2505-
indices = [self._get_index(name) if true_and_notnull(group) else []
2506-
for name, group in self]
2510+
indices = []
2511+
for name, group in self:
2512+
if true_and_notnull(group) and name in self.indices:
2513+
indices.append(self.indices[name])
2514+
25072515
except ValueError:
25082516
raise TypeError("the filter must return a boolean result")
25092517
except TypeError:
@@ -3020,24 +3028,18 @@ def transform(self, func, *args, **kwargs):
30203028
if not result.columns.equals(obj.columns):
30213029
return self._transform_general(func, *args, **kwargs)
30223030

3023-
# a grouped that doesn't preserve the index, remap index based on the grouper
3024-
# and broadcast it
3025-
if ((not isinstance(obj.index,MultiIndex) and
3026-
type(result.index) != type(obj.index)) or
3027-
len(result.index) != len(obj.index)):
3028-
results = np.empty_like(obj.values, result.values.dtype)
3029-
indices = self.indices
3030-
for (name, group), (i, row) in zip(self, result.iterrows()):
3031+
results = np.empty_like(obj.values, result.values.dtype)
3032+
indices = self.indices
3033+
for (name, group), (i, row) in zip(self, result.iterrows()):
3034+
if name in indices:
30313035
indexer = indices[name]
30323036
results[indexer] = np.tile(row.values,len(indexer)).reshape(len(indexer),-1)
3033-
return DataFrame(results,columns=result.columns,index=obj.index).convert_objects()
30343037

3035-
# we can merge the result in
3036-
# GH 7383
3037-
names = result.columns
3038-
result = obj.merge(result, how='outer', left_index=True, right_index=True).iloc[:,-result.shape[1]:]
3039-
result.columns = names
3040-
return result
3038+
counts = self.size().fillna(0).values
3039+
if any(counts == 0):
3040+
results = self._try_cast(results, obj[result.columns])
3041+
3042+
return DataFrame(results,columns=result.columns,index=obj.index).convert_objects()
30413043

30423044
def _define_paths(self, func, *args, **kwargs):
30433045
if isinstance(func, compat.string_types):
@@ -3129,10 +3131,9 @@ def filter(self, func, dropna=True, *args, **kwargs):
31293131
pass
31303132

31313133
# interpret the result of the filter
3132-
if (isinstance(res, (bool, np.bool_)) or
3133-
np.isscalar(res) and isnull(res)):
3134-
if res and notnull(res):
3135-
indices.append(self._get_index(name))
3134+
if is_bool(res) or (lib.isscalar(res) and isnull(res)):
3135+
if res and notnull(res) and name in self.indices:
3136+
indices.append(self.indices[name])
31363137
else:
31373138
# non scalars aren't allowed
31383139
raise TypeError("filter function returned a %s, "

pandas/tests/test_categorical.py

+21
Original file line number | Diff line number | Diff line change
@@ -1820,6 +1820,27 @@ def f(x):
18201820
expected['person_name'] = expected['person_name'].astype('object')
18211821
tm.assert_frame_equal(result, expected)
18221822

1823+
# GH 9921
1824+
# Monotonic
1825+
df = DataFrame({"a": [5, 15, 25]})
1826+
c = pd.cut(df.a, bins=[0,10,20,30,40])
1827+
tm.assert_series_equal(df.a.groupby(c).transform(sum), df['a'])
1828+
tm.assert_series_equal(df.a.groupby(c).transform(lambda xs: np.sum(xs)), df['a'])
1829+
tm.assert_frame_equal(df.groupby(c).transform(sum), df[['a']])
1830+
tm.assert_frame_equal(df.groupby(c).transform(lambda xs: np.max(xs)), df[['a']])
1831+
1832+
# Filter
1833+
tm.assert_series_equal(df.a.groupby(c).filter(np.all), df['a'])
1834+
tm.assert_frame_equal(df.groupby(c).filter(np.all), df)
1835+
1836+
# Non-monotonic
1837+
df = DataFrame({"a": [5, 15, 25, -5]})
1838+
c = pd.cut(df.a, bins=[-10, 0,10,20,30,40])
1839+
tm.assert_series_equal(df.a.groupby(c).transform(sum), df['a'])
1840+
tm.assert_series_equal(df.a.groupby(c).transform(lambda xs: np.sum(xs)), df['a'])
1841+
tm.assert_frame_equal(df.groupby(c).transform(sum), df[['a']])
1842+
tm.assert_frame_equal(df.groupby(c).transform(lambda xs: np.sum(xs)), df[['a']])
1843+
18231844
def test_pivot_table(self):
18241845

18251846
raw_cat1 = Categorical(["a","a","b","b"], categories=["a","b","z"], ordered=True)

pandas/tests/test_groupby.py

+6
Original file line number | Diff line number | Diff line change
@@ -960,6 +960,12 @@ def demean(arr):
960960
g = df.groupby(pd.TimeGrouper('M'))
961961
g.transform(lambda x: x-1)
962962

963+
# GH 9700
964+
df = DataFrame({'a' : range(5, 10), 'b' : range(5)})
965+
result = df.groupby('a').transform(max)
966+
expected = DataFrame({'b' : range(5)})
967+
tm.assert_frame_equal(result, expected)
968+
963969
def test_transform_fast(self):
964970

965971
df = DataFrame( { 'id' : np.arange( 100000 ) / 3,

0 commit comments

Comments
 (0)