Skip to content

Commit 2d876be

Browse files
committed
Merge pull request #6578 from hayd/groupby_selecte_badcols
FIX: raise KeyError when a groupby column selection names columns that are not in the frame
2 parents d5f9493 + 6968da8 commit 2d876be

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

pandas/core/groupby.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -2853,15 +2853,27 @@ def __getitem__(self, key):
28532853
if self._selection is not None:
28542854
raise Exception('Column(s) %s already selected' % self._selection)
28552855

2856-
if (isinstance(key, (list, tuple, Series, np.ndarray)) or
2857-
not self.as_index):
2856+
if isinstance(key, (list, tuple, Series, np.ndarray)):
2857+
if len(self.obj.columns.intersection(key)) != len(key):
2858+
bad_keys = list(set(key).difference(self.obj.columns))
2859+
raise KeyError("Columns not found: %s"
2860+
% str(bad_keys)[1:-1])
28582861
return DataFrameGroupBy(self.obj, self.grouper, selection=key,
28592862
grouper=self.grouper,
28602863
exclusions=self.exclusions,
28612864
as_index=self.as_index)
2865+
2866+
elif not self.as_index:
2867+
if key not in self.obj.columns:
2868+
raise KeyError("Column not found: %s" % key)
2869+
return DataFrameGroupBy(self.obj, self.grouper, selection=key,
2870+
grouper=self.grouper,
2871+
exclusions=self.exclusions,
2872+
as_index=self.as_index)
2873+
28622874
else:
2863-
if key not in self.obj: # pragma: no cover
2864-
raise KeyError(str(key))
2875+
if key not in self.obj:
2876+
raise KeyError("Column not found: %s" % key)
28652877
# kind of a kludge
28662878
return SeriesGroupBy(self.obj[key], selection=key,
28672879
grouper=self.grouper,

pandas/tests/test_groupby.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pandas.core.series import Series
1616
from pandas.util.testing import (assert_panel_equal, assert_frame_equal,
1717
assert_series_equal, assert_almost_equal,
18-
assert_index_equal)
18+
assert_index_equal, assertRaisesRegexp)
1919
from pandas.compat import(
2020
range, long, lrange, StringIO, lmap, lzip, map,
2121
zip, builtins, OrderedDict
@@ -33,6 +33,7 @@
3333
import pandas as pd
3434
from numpy.testing import assert_equal
3535

36+
3637
def commonSetUp(self):
3738
self.dateRange = bdate_range('1/1/2005', periods=250)
3839
self.stringIndex = Index([rands(8).upper() for x in range(250)])
@@ -75,7 +76,8 @@ def setUp(self):
7576
'B': ['one', 'one', 'two', 'three',
7677
'two', 'two', 'one', 'three'],
7778
'C': np.random.randn(8),
78-
'D': np.array(np.random.randn(8),dtype='float32')})
79+
'D': np.array(np.random.randn(8),
80+
dtype='float32')})
7981

8082
index = MultiIndex(levels=[['foo', 'bar', 'baz', 'qux'],
8183
['one', 'two', 'three']],
@@ -117,7 +119,7 @@ def checkit(dtype):
117119

118120
assert_series_equal(agged, grouped.agg(np.mean)) # shorthand
119121
assert_series_equal(agged, grouped.mean())
120-
assert_series_equal(grouped.agg(np.sum),grouped.sum())
122+
assert_series_equal(grouped.agg(np.sum), grouped.sum())
121123

122124
transformed = grouped.transform(lambda x: x * x.sum())
123125
self.assertEqual(transformed[7], 12)
@@ -141,10 +143,20 @@ def checkit(dtype):
141143
# corner cases
142144
self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2)
143145

144-
145-
for dtype in ['int64','int32','float64','float32']:
146+
for dtype in ['int64', 'int32', 'float64', 'float32']:
146147
checkit(dtype)
147148

149+
def test_select_bad_cols(self):
150+
df = DataFrame([[1, 2]], columns=['A', 'B'])
151+
g = df.groupby('A')
152+
self.assertRaises(KeyError, g.__getitem__, ['C']) # g[['C']]
153+
154+
self.assertRaises(KeyError, g.__getitem__, ['A', 'C']) # g[['A', 'C']]
155+
with assertRaisesRegexp(KeyError, '^[^A]+$'):
156+
# A should not be referenced as a bad column...
157+
# will have to rethink regex if you change message!
158+
g[['A', 'C']]
159+
148160
def test_first_last_nth(self):
149161
# tests for first / last / nth
150162
grouped = self.df.groupby('A')

0 commit comments

Comments
 (0)