diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 031088c4e5672..770c4ae27dbdb 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -2568,15 +2568,27 @@ def __getitem__(self, key): if self._selection is not None: raise Exception('Column(s) %s already selected' % self._selection) - if (isinstance(key, (list, tuple, Series, np.ndarray)) or - not self.as_index): + if isinstance(key, (list, tuple, Series, np.ndarray)): + if len(self.obj.columns.intersection(key)) != len(key): + bad_keys = list(set(key).difference(self.obj.columns)) + raise KeyError("Columns not found: %s" + % str(bad_keys)[1:-1]) return DataFrameGroupBy(self.obj, self.grouper, selection=key, grouper=self.grouper, exclusions=self.exclusions, as_index=self.as_index) + + elif not self.as_index: + if key not in self.obj.columns: + raise KeyError("Column not found: %s" % key) + return DataFrameGroupBy(self.obj, self.grouper, selection=key, + grouper=self.grouper, + exclusions=self.exclusions, + as_index=self.as_index) + else: - if key not in self.obj: # pragma: no cover - raise KeyError(str(key)) + if key not in self.obj: + raise KeyError("Column not found: %s" % key) # kind of a kludge return SeriesGroupBy(self.obj[key], selection=key, grouper=self.grouper, diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index 8bbc8e6326639..01bff979ee9f6 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -14,7 +14,7 @@ from pandas.core.series import Series from pandas.util.testing import (assert_panel_equal, assert_frame_equal, assert_series_equal, assert_almost_equal, - assert_index_equal) + assert_index_equal, assertRaisesRegexp) from pandas.compat import( range, long, lrange, StringIO, lmap, lzip, map, zip, builtins, OrderedDict ) @@ -30,6 +30,7 @@ import pandas.util.testing as tm import pandas as pd + def commonSetUp(self): self.dateRange = bdate_range('1/1/2005', periods=250) self.stringIndex = Index([rands(8).upper() for x in range(250)]) @@ -72,7 +73,8 @@ def setUp(self): 'B': ['one', 'one', 'two', 'three', 'two', 'two', 'one', 'three'], 'C': np.random.randn(8), - 'D': np.array(np.random.randn(8),dtype='float32')}) + 'D': np.array(np.random.randn(8), + dtype='float32')}) index = MultiIndex(levels=[['foo', 'bar', 'baz', 'qux'], ['one', 'two', 'three']], @@ -114,7 +116,7 @@ def checkit(dtype): assert_series_equal(agged, grouped.agg(np.mean)) # shorthand assert_series_equal(agged, grouped.mean()) - assert_series_equal(grouped.agg(np.sum),grouped.sum()) + assert_series_equal(grouped.agg(np.sum), grouped.sum()) transformed = grouped.transform(lambda x: x * x.sum()) self.assertEqual(transformed[7], 12) @@ -138,10 +140,20 @@ def checkit(dtype): # corner cases self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2) - - for dtype in ['int64','int32','float64','float32']: + for dtype in ['int64', 'int32', 'float64', 'float32']: checkit(dtype) + def test_select_bad_cols(self): + df = DataFrame([[1, 2]], columns=['A', 'B']) + g = df.groupby('A') + self.assertRaises(KeyError, g.__getitem__, ['C']) # g[['C']] + + self.assertRaises(KeyError, g.__getitem__, ['A', 'C']) # g[['A', 'C']] + with assertRaisesRegexp(KeyError, '^[^A]+$'): + # A should not be referenced as a bad column... + # will have to rethink regex if you change message! + g[['A', 'C']] + def test_first_last_nth(self): # tests for first / last / nth grouped = self.df.groupby('A')