Skip to content

Commit c3d31c7

Browse files
committed
Merge pull request #7871 from cpcloud/groupby-filter-fix
BUG/FIX: groupby should raise on multi-valued filter
2 parents 9da121e + f0fc4b5 commit c3d31c7

File tree

3 files changed

+42
-29
lines changed

3 files changed

+42
-29
lines changed

doc/source/v0.15.0.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@ Bug Fixes
346346

347347

348348

349-
349+
- Bug in ``GroupBy.filter()`` where fast path vs. slow path made the filter
350+
return a non scalar value that appeared valid but wasnt' (:issue:`7870`).
350351

351352

352353

pandas/core/groupby.py

+14-28
Original file line numberDiff line numberDiff line change
@@ -2945,48 +2945,34 @@ def filter(self, func, dropna=True, *args, **kwargs):
29452945
>>> grouped = df.groupby(lambda x: mapping[x])
29462946
>>> grouped.filter(lambda x: x['A'].sum() + x['B'].sum() > 0)
29472947
"""
2948-
from pandas.tools.merge import concat
29492948

29502949
indices = []
29512950

29522951
obj = self._selected_obj
29532952
gen = self.grouper.get_iterator(obj, axis=self.axis)
29542953

2955-
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
2956-
2957-
path = None
29582954
for name, group in gen:
29592955
object.__setattr__(group, 'name', name)
29602956

2961-
if path is None:
2962-
# Try slow path and fast path.
2963-
try:
2964-
path, res = self._choose_path(fast_path, slow_path, group)
2965-
except Exception: # pragma: no cover
2966-
res = fast_path(group)
2967-
path = fast_path
2968-
else:
2969-
res = path(group)
2957+
res = func(group)
29702958

2971-
def add_indices():
2972-
indices.append(self._get_index(name))
2959+
try:
2960+
res = res.squeeze()
2961+
except AttributeError: # allow e.g., scalars and frames to pass
2962+
pass
29732963

29742964
# interpret the result of the filter
2975-
if isinstance(res, (bool, np.bool_)):
2976-
if res:
2977-
add_indices()
2965+
if (isinstance(res, (bool, np.bool_)) or
2966+
np.isscalar(res) and isnull(res)):
2967+
if res and notnull(res):
2968+
indices.append(self._get_index(name))
29782969
else:
2979-
if getattr(res, 'ndim', None) == 1:
2980-
val = res.ravel()[0]
2981-
if val and notnull(val):
2982-
add_indices()
2983-
else:
2984-
2985-
# in theory you could do .all() on the boolean result ?
2986-
raise TypeError("the filter must return a boolean result")
2970+
# non scalars aren't allowed
2971+
raise TypeError("filter function returned a %s, "
2972+
"but expected a scalar bool" %
2973+
type(res).__name__)
29872974

2988-
filtered = self._apply_filter(indices, dropna)
2989-
return filtered
2975+
return self._apply_filter(indices, dropna)
29902976

29912977

29922978
class DataFrameGroupBy(NDFrameGroupBy):

pandas/tests/test_groupby.py

+26
Original file line numberDiff line numberDiff line change
@@ -3968,6 +3968,32 @@ def test_filter_has_access_to_grouped_cols(self):
39683968
filt = g.filter(lambda x: x['A'].sum() == 2)
39693969
assert_frame_equal(filt, df.iloc[[0, 1]])
39703970

3971+
def test_filter_enforces_scalarness(self):
3972+
df = pd.DataFrame([
3973+
['best', 'a', 'x'],
3974+
['worst', 'b', 'y'],
3975+
['best', 'c', 'x'],
3976+
['best','d', 'y'],
3977+
['worst','d', 'y'],
3978+
['worst','d', 'y'],
3979+
['best','d', 'z'],
3980+
], columns=['a', 'b', 'c'])
3981+
with tm.assertRaisesRegexp(TypeError, 'filter function returned a.*'):
3982+
df.groupby('c').filter(lambda g: g['a'] == 'best')
3983+
3984+
def test_filter_non_bool_raises(self):
3985+
df = pd.DataFrame([
3986+
['best', 'a', 1],
3987+
['worst', 'b', 1],
3988+
['best', 'c', 1],
3989+
['best','d', 1],
3990+
['worst','d', 1],
3991+
['worst','d', 1],
3992+
['best','d', 1],
3993+
], columns=['a', 'b', 'c'])
3994+
with tm.assertRaisesRegexp(TypeError, 'filter function returned a.*'):
3995+
df.groupby('a').filter(lambda g: g.c.mean())
3996+
39713997
def test_index_label_overlaps_location(self):
39723998
# checking we don't have any label/location confusion in the
39733999
# the wake of GH5375

0 commit comments

Comments
 (0)