Skip to content

Commit 6968da8

Browse files
committed
FIX: raise KeyError when a groupby selection names columns that are not in the frame
1 parent 48a6849 commit 6968da8

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

pandas/core/groupby.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -2568,15 +2568,27 @@ def __getitem__(self, key):
25682568
if self._selection is not None:
25692569
raise Exception('Column(s) %s already selected' % self._selection)
25702570

2571-
if (isinstance(key, (list, tuple, Series, np.ndarray)) or
2572-
not self.as_index):
2571+
if isinstance(key, (list, tuple, Series, np.ndarray)):
2572+
if len(self.obj.columns.intersection(key)) != len(key):
2573+
bad_keys = list(set(key).difference(self.obj.columns))
2574+
raise KeyError("Columns not found: %s"
2575+
% str(bad_keys)[1:-1])
25732576
return DataFrameGroupBy(self.obj, self.grouper, selection=key,
25742577
grouper=self.grouper,
25752578
exclusions=self.exclusions,
25762579
as_index=self.as_index)
2580+
2581+
elif not self.as_index:
2582+
if key not in self.obj.columns:
2583+
raise KeyError("Column not found: %s" % key)
2584+
return DataFrameGroupBy(self.obj, self.grouper, selection=key,
2585+
grouper=self.grouper,
2586+
exclusions=self.exclusions,
2587+
as_index=self.as_index)
2588+
25772589
else:
2578-
if key not in self.obj: # pragma: no cover
2579-
raise KeyError(str(key))
2590+
if key not in self.obj:
2591+
raise KeyError("Column not found: %s" % key)
25802592
# kind of a kludge
25812593
return SeriesGroupBy(self.obj[key], selection=key,
25822594
grouper=self.grouper,

pandas/tests/test_groupby.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pandas.core.series import Series
1515
from pandas.util.testing import (assert_panel_equal, assert_frame_equal,
1616
assert_series_equal, assert_almost_equal,
17-
assert_index_equal)
17+
assert_index_equal, assertRaisesRegexp)
1818
from pandas.compat import(
1919
range, long, lrange, StringIO, lmap, lzip, map, zip, builtins, OrderedDict
2020
)
@@ -30,6 +30,7 @@
3030
import pandas.util.testing as tm
3131
import pandas as pd
3232

33+
3334
def commonSetUp(self):
3435
self.dateRange = bdate_range('1/1/2005', periods=250)
3536
self.stringIndex = Index([rands(8).upper() for x in range(250)])
@@ -72,7 +73,8 @@ def setUp(self):
7273
'B': ['one', 'one', 'two', 'three',
7374
'two', 'two', 'one', 'three'],
7475
'C': np.random.randn(8),
75-
'D': np.array(np.random.randn(8),dtype='float32')})
76+
'D': np.array(np.random.randn(8),
77+
dtype='float32')})
7678

7779
index = MultiIndex(levels=[['foo', 'bar', 'baz', 'qux'],
7880
['one', 'two', 'three']],
@@ -114,7 +116,7 @@ def checkit(dtype):
114116

115117
assert_series_equal(agged, grouped.agg(np.mean)) # shorthand
116118
assert_series_equal(agged, grouped.mean())
117-
assert_series_equal(grouped.agg(np.sum),grouped.sum())
119+
assert_series_equal(grouped.agg(np.sum), grouped.sum())
118120

119121
transformed = grouped.transform(lambda x: x * x.sum())
120122
self.assertEqual(transformed[7], 12)
@@ -138,10 +140,20 @@ def checkit(dtype):
138140
# corner cases
139141
self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2)
140142

141-
142-
for dtype in ['int64','int32','float64','float32']:
143+
for dtype in ['int64', 'int32', 'float64', 'float32']:
143144
checkit(dtype)
144145

146+
def test_select_bad_cols(self):
147+
df = DataFrame([[1, 2]], columns=['A', 'B'])
148+
g = df.groupby('A')
149+
self.assertRaises(KeyError, g.__getitem__, ['C']) # g[['C']]
150+
151+
self.assertRaises(KeyError, g.__getitem__, ['A', 'C']) # g[['A', 'C']]
152+
with assertRaisesRegexp(KeyError, '^[^A]+$'):
153+
# A should not be referenced as a bad column...
154+
# will have to rethink regex if you change message!
155+
g[['A', 'C']]
156+
145157
def test_first_last_nth(self):
146158
# tests for first / last / nth
147159
grouped = self.df.groupby('A')

0 commit comments

Comments
 (0)