Skip to content

Commit 6fdb400

Browse files
committed
BUG: fix groupby with multiple non-compressed categoricals
1 parent 97aece1 commit 6fdb400

File tree

2 files changed

+64
-42
lines changed

2 files changed

+64
-42
lines changed

pandas/core/groupby.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,9 @@ def get_group_levels(self):
13611361
name_list = []
13621362
for ping, labels in zip(self.groupings, recons_labels):
13631363
labels = com._ensure_platform_int(labels)
1364-
name_list.append(ping.group_index.take(labels))
1364+
levels = ping.group_index.take(labels)
1365+
1366+
name_list.append(levels)
13651367

13661368
return name_list
13671369

@@ -1707,6 +1709,11 @@ def levels(self):
17071709
def names(self):
17081710
return [self.binlabels.name]
17091711

1712+
@property
1713+
def groupings(self):
1714+
# for compat: this grouper has no groupings, so callers check for None
1715+
return None
1716+
17101717
def size(self):
17111718
"""
17121719
Compute group sizes
@@ -2632,7 +2639,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
26322639
if isinstance(values[0], DataFrame):
26332640
return self._concat_objects(keys, values,
26342641
not_indexed_same=not_indexed_same)
2635-
elif hasattr(self.grouper, 'groupings'):
2642+
elif self.grouper.groupings is not None:
26362643
if len(self.grouper.groupings) > 1:
26372644
key_index = MultiIndex.from_tuples(keys, names=key_names)
26382645

@@ -3058,7 +3065,7 @@ def _wrap_aggregated_output(self, output, names=None):
30583065
if self.axis == 1:
30593066
result = result.T
30603067

3061-
return result.convert_objects()
3068+
return self._reindex_output(result).convert_objects()
30623069

30633070
def _wrap_agged_blocks(self, items, blocks):
30643071
if not self.as_index:
@@ -3080,7 +3087,27 @@ def _wrap_agged_blocks(self, items, blocks):
30803087
if self.axis == 1:
30813088
result = result.T
30823089

3083-
return result.convert_objects()
3090+
return self._reindex_output(result).convert_objects()
3091+
3092+
def _reindex_output(self, result):
3093+
"""
3094+
If we have categorical groupers, then we want to make sure that
3095+
the output is fully reindexed to the levels. Some levels may not have participated in
3096+
the groupings (e.g. they may have all been nan groups)
3097+
3098+
This can re-expand the output space
3099+
"""
3100+
groupings = self.grouper.groupings
3101+
if groupings is None:
3102+
return result
3103+
elif len(groupings) == 1:
3104+
return result
3105+
elif not any([ping._was_factor for ping in groupings]):
3106+
return result
3107+
3108+
levels_list = [ ping._group_index for ping in groupings ]
3109+
index = MultiIndex.from_product(levels_list, names=self.grouper.names)
3110+
return result.reindex(**{ self.obj._get_axis_name(self.axis) : index, 'copy' : False }).sortlevel()
30843111

30853112
def _iterate_column_groupbys(self):
30863113
for i, colname in enumerate(self._selected_obj.columns):

pandas/tests/test_categorical.py

+33-38
Original file line numberDiff line numberDiff line change
@@ -797,11 +797,6 @@ def test_repr(self):
797797
self.assertEqual(exp,a.__unicode__())
798798

799799

800-
def test_groupby(self):
801-
802-
result = self.cat['value_group'].unique()
803-
result = self.cat.groupby(['value_group'])['value_group'].count()
804-
805800
def test_groupby_sort(self):
806801

807802
# http://stackoverflow.com/questions/23814368/sorting-pandas-categorical-labels-after-groupby
@@ -872,52 +867,52 @@ def test_groupby(self):
872867
cats = Categorical(["a", "a", "a", "b", "b", "b", "c", "c", "c"], levels=["a","b","c","d"])
873868
data = DataFrame({"a":[1,1,1,2,2,2,3,4,5], "b":cats})
874869

870+
expected = DataFrame({ 'a' : Series([1,2,4,np.nan],index=Index(['a','b','c','d'],name='b')) })
875871
result = data.groupby("b").mean()
876-
result = result["a"].values
877-
exp = np.array([1,2,4,np.nan])
878-
self.assert_numpy_array_equivalent(result, exp)
879-
880-
### FIXME ###
881-
882-
#res = len(data.groupby("b"))
883-
#self.assertEqual(res ,4)
872+
tm.assert_frame_equal(result, expected)
884873

885874
raw_cat1 = Categorical(["a","a","b","b"], levels=["a","b","z"])
886875
raw_cat2 = Categorical(["c","d","c","d"], levels=["c","d","y"])
887876
df = DataFrame({"A":raw_cat1,"B":raw_cat2, "values":[1,2,3,4]})
888-
gb = df.groupby("A")
889877

890-
#idx = gb.indices
891-
#self.assertEqual(len(gb), 3)
892-
#num = 0
893-
#for _ in gb:
894-
# num +=1
895-
#self.assertEqual(len(gb), 3)
896-
#gb = df.groupby(["B"])
897-
#idx2 = gb.indices
898-
#self.assertEqual(len(gb), 3)
899-
#num = 0
900-
#for _ in gb:
901-
# num +=1
902-
#self.assertEqual(len(gb), 3)
903-
#gb = df.groupby(["A","B"])
904-
#res = len(gb)
905-
#idx3 = gb.indices
906-
#self.assertEqual(res, 9)
907-
#num = 0
908-
#for _ in gb:
909-
# num +=1
910-
#self.assertEqual(len(gb), 9)
878+
# single grouper
879+
gb = df.groupby("A")
880+
expected = DataFrame({ 'values' : Series([3,7,np.nan],index=Index(['a','b','z'],name='A')) })
881+
result = gb.sum()
882+
tm.assert_frame_equal(result, expected)
883+
884+
# multiple groupers
885+
gb = df.groupby(['A','B'])
886+
expected = DataFrame({ 'values' : Series([1,2,np.nan,3,4,np.nan,np.nan,np.nan,np.nan],
887+
index=pd.MultiIndex.from_product([['a','b','z'],['c','d','y']],names=['A','B'])) })
888+
result = gb.sum()
889+
tm.assert_frame_equal(result, expected)
890+
891+
# multiple groupers, including a non-categorical one
892+
df = df.copy()
893+
df['C'] = ['foo','bar']*2
894+
gb = df.groupby(['A','B','C'])
895+
expected = DataFrame({ 'values' :
896+
Series(np.nan,index=pd.MultiIndex.from_product([['a','b','z'],
897+
['c','d','y'],
898+
['foo','bar']],
899+
names=['A','B','C']))
900+
}).sortlevel()
901+
expected.iloc[[1,2,7,8],0] = [1,2,3,4]
902+
result = gb.sum()
903+
tm.assert_frame_equal(result, expected)
911904

912905
def test_pivot_table(self):
913906

914907
raw_cat1 = Categorical(["a","a","b","b"], levels=["a","b","z"])
915908
raw_cat2 = Categorical(["c","d","c","d"], levels=["c","d","y"])
916909
df = DataFrame({"A":raw_cat1,"B":raw_cat2, "values":[1,2,3,4]})
917-
res = pd.pivot_table(df, values='values', index=['A', 'B'])
910+
result = pd.pivot_table(df, values='values', index=['A', 'B'])
918911

919-
### FIXME ###
920-
#self.assertEqual(len(res), 9)
912+
expected = Series([1,2,np.nan,3,4,np.nan,np.nan,np.nan,np.nan],
913+
index=pd.MultiIndex.from_product([['a','b','z'],['c','d','y']],names=['A','B']),
914+
name='values')
915+
tm.assert_series_equal(result, expected)
921916

922917
def test_count(self):
923918

0 commit comments

Comments
 (0)