Skip to content

Commit 889b9f2

Browse files
committed
BUG/TST: Test filter corner cases, and treat NaN as False.
1 parent bea34eb commit 889b9f2

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

pandas/core/groupby.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pandas.util.decorators import cache_readonly, Appender
1818
import pandas.core.algorithms as algos
1919
import pandas.core.common as com
20-
from pandas.core.common import _possibly_downcast_to_dtype, notnull
20+
from pandas.core.common import _possibly_downcast_to_dtype, isnull, notnull
2121

2222
import pandas.lib as lib
2323
import pandas.algos as _algos
@@ -1605,8 +1605,19 @@ def filter(self, func, dropna=True, *args, **kwargs):
16051605
else:
16061606
wrapper = lambda x: func(x, *args, **kwargs)
16071607

1608-
indexers = [self.obj.index.get_indexer(group.index) \
1609-
if wrapper(group) else [] for _ , group in self]
1608+
# Interpret np.nan as False.
1609+
def true_and_notnull(x, *args, **kwargs):
1610+
b = wrapper(x, *args, **kwargs)
1611+
return b and notnull(b)
1612+
1613+
try:
1614+
indexers = [self.obj.index.get_indexer(group.index) \
1615+
if true_and_notnull(group) else [] \
1616+
for _ , group in self]
1617+
except ValueError:
1618+
raise TypeError("the filter must return a boolean result")
1619+
except TypeError:
1620+
raise TypeError("the filter must return a boolean result")
16101621

16111622
if len(indexers) == 0:
16121623
filtered = self.obj.take([]) # because np.concatenate would fail
@@ -2124,7 +2135,8 @@ def add_indexer():
21242135
add_indexer()
21252136
else:
21262137
if getattr(res,'ndim',None) == 1:
2127-
if res.ravel()[0]:
2138+
val = res.ravel()[0]
2139+
if val and notnull(val):
21282140
add_indexer()
21292141
else:
21302142

pandas/tests/test_groupby.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -2642,9 +2642,37 @@ def raise_if_sum_is_zero(x):
26422642
s = pd.Series([-1,0,1,2])
26432643
grouper = s.apply(lambda x: x % 2)
26442644
grouped = s.groupby(grouper)
2645-
self.assertRaises(ValueError,
2645+
self.assertRaises(TypeError,
26462646
lambda: grouped.filter(raise_if_sum_is_zero))
26472647

2648+
def test_filter_bad_shapes(self):
2649+
df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
2650+
s = df['B']
2651+
g_df = df.groupby('B')
2652+
g_s = s.groupby(s)
2653+
2654+
f = lambda x: x
2655+
self.assertRaises(TypeError, lambda: g_df.filter(f))
2656+
self.assertRaises(TypeError, lambda: g_s.filter(f))
2657+
2658+
f = lambda x: x == 1
2659+
self.assertRaises(TypeError, lambda: g_df.filter(f))
2660+
self.assertRaises(TypeError, lambda: g_s.filter(f))
2661+
2662+
f = lambda x: np.outer(x, x)
2663+
self.assertRaises(TypeError, lambda: g_df.filter(f))
2664+
self.assertRaises(TypeError, lambda: g_s.filter(f))
2665+
2666+
def test_filter_nan_is_false(self):
2667+
df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
2668+
s = df['B']
2669+
g_df = df.groupby(df['B'])
2670+
g_s = s.groupby(s)
2671+
2672+
f = lambda x: np.nan
2673+
assert_frame_equal(g_df.filter(f), df.loc[[]])
2674+
assert_series_equal(g_s.filter(f), s[[]])
2675+
26482676
def test_filter_against_workaround(self):
26492677
np.random.seed(0)
26502678
# Series of ints
@@ -2697,6 +2725,29 @@ def test_filter_against_workaround(self):
26972725
new_way = grouped.filter(lambda x: x['ints'].mean() > N/20)
26982726
assert_frame_equal(new_way.sort_index(), old_way.sort_index())
26992727

2728+
def test_filter_using_len(self):
2729+
# BUG GH4447
2730+
df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
2731+
grouped = df.groupby('B')
2732+
actual = grouped.filter(lambda x: len(x) > 2)
2733+
expected = DataFrame({'A': np.arange(2, 6), 'B': list('bbbb'), 'C': np.arange(2, 6)}, index=np.arange(2, 6))
2734+
assert_frame_equal(actual, expected)
2735+
2736+
actual = grouped.filter(lambda x: len(x) > 4)
2737+
expected = df.ix[[]]
2738+
assert_frame_equal(actual, expected)
2739+
2740+
# Series have always worked properly, but we'll test anyway.
2741+
s = df['B']
2742+
grouped = s.groupby(s)
2743+
actual = grouped.filter(lambda x: len(x) > 2)
2744+
expected = Series(4*['b'], index=np.arange(2, 6))
2745+
assert_series_equal(actual, expected)
2746+
2747+
actual = grouped.filter(lambda x: len(x) > 4)
2748+
expected = s[[]]
2749+
assert_series_equal(actual, expected)
2750+
27002751
def test_groupby_whitelist(self):
27012752
from string import ascii_lowercase
27022753
letters = np.array(list(ascii_lowercase))

0 commit comments

Comments
 (0)