Skip to content

Commit 9ec30ad

Browse files
committed
Merge pull request #5096 from danielballan/filter-len-test
TST: Groupby filter tests involved len, closing #4447
2 parents 1aa55d8 + 889b9f2 commit 9ec30ad

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

pandas/core/groupby.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pandas.util.decorators import cache_readonly, Appender
1919
import pandas.core.algorithms as algos
2020
import pandas.core.common as com
21-
from pandas.core.common import _possibly_downcast_to_dtype, notnull
21+
from pandas.core.common import _possibly_downcast_to_dtype, isnull, notnull
2222

2323
import pandas.lib as lib
2424
import pandas.algos as _algos
@@ -1624,8 +1624,19 @@ def filter(self, func, dropna=True, *args, **kwargs):
16241624
else:
16251625
wrapper = lambda x: func(x, *args, **kwargs)
16261626

1627-
indexers = [self.obj.index.get_indexer(group.index) \
1628-
if wrapper(group) else [] for _ , group in self]
1627+
# Interpret np.nan as False.
1628+
def true_and_notnull(x, *args, **kwargs):
1629+
b = wrapper(x, *args, **kwargs)
1630+
return b and notnull(b)
1631+
1632+
try:
1633+
indexers = [self.obj.index.get_indexer(group.index) \
1634+
if true_and_notnull(group) else [] \
1635+
for _ , group in self]
1636+
except ValueError:
1637+
raise TypeError("the filter must return a boolean result")
1638+
except TypeError:
1639+
raise TypeError("the filter must return a boolean result")
16291640

16301641
if len(indexers) == 0:
16311642
filtered = self.obj.take([]) # because np.concatenate would fail
@@ -2144,7 +2155,8 @@ def add_indexer():
21442155
add_indexer()
21452156
else:
21462157
if getattr(res,'ndim',None) == 1:
2147-
if res.ravel()[0]:
2158+
val = res.ravel()[0]
2159+
if val and notnull(val):
21482160
add_indexer()
21492161
else:
21502162

pandas/tests/test_groupby.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -2641,9 +2641,37 @@ def raise_if_sum_is_zero(x):
26412641
s = pd.Series([-1,0,1,2])
26422642
grouper = s.apply(lambda x: x % 2)
26432643
grouped = s.groupby(grouper)
2644-
self.assertRaises(ValueError,
2644+
self.assertRaises(TypeError,
26452645
lambda: grouped.filter(raise_if_sum_is_zero))
26462646

2647+
def test_filter_bad_shapes(self):
2648+
df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
2649+
s = df['B']
2650+
g_df = df.groupby('B')
2651+
g_s = s.groupby(s)
2652+
2653+
f = lambda x: x
2654+
self.assertRaises(TypeError, lambda: g_df.filter(f))
2655+
self.assertRaises(TypeError, lambda: g_s.filter(f))
2656+
2657+
f = lambda x: x == 1
2658+
self.assertRaises(TypeError, lambda: g_df.filter(f))
2659+
self.assertRaises(TypeError, lambda: g_s.filter(f))
2660+
2661+
f = lambda x: np.outer(x, x)
2662+
self.assertRaises(TypeError, lambda: g_df.filter(f))
2663+
self.assertRaises(TypeError, lambda: g_s.filter(f))
2664+
2665+
def test_filter_nan_is_false(self):
2666+
df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
2667+
s = df['B']
2668+
g_df = df.groupby(df['B'])
2669+
g_s = s.groupby(s)
2670+
2671+
f = lambda x: np.nan
2672+
assert_frame_equal(g_df.filter(f), df.loc[[]])
2673+
assert_series_equal(g_s.filter(f), s[[]])
2674+
26472675
def test_filter_against_workaround(self):
26482676
np.random.seed(0)
26492677
# Series of ints
@@ -2696,6 +2724,29 @@ def test_filter_against_workaround(self):
26962724
new_way = grouped.filter(lambda x: x['ints'].mean() > N/20)
26972725
assert_frame_equal(new_way.sort_index(), old_way.sort_index())
26982726

2727+
def test_filter_using_len(self):
2728+
# BUG GH4447
2729+
df = DataFrame({'A': np.arange(8), 'B': list('aabbbbcc'), 'C': np.arange(8)})
2730+
grouped = df.groupby('B')
2731+
actual = grouped.filter(lambda x: len(x) > 2)
2732+
expected = DataFrame({'A': np.arange(2, 6), 'B': list('bbbb'), 'C': np.arange(2, 6)}, index=np.arange(2, 6))
2733+
assert_frame_equal(actual, expected)
2734+
2735+
actual = grouped.filter(lambda x: len(x) > 4)
2736+
expected = df.ix[[]]
2737+
assert_frame_equal(actual, expected)
2738+
2739+
# Series have always worked properly, but we'll test anyway.
2740+
s = df['B']
2741+
grouped = s.groupby(s)
2742+
actual = grouped.filter(lambda x: len(x) > 2)
2743+
expected = Series(4*['b'], index=np.arange(2, 6))
2744+
assert_series_equal(actual, expected)
2745+
2746+
actual = grouped.filter(lambda x: len(x) > 4)
2747+
expected = s[[]]
2748+
assert_series_equal(actual, expected)
2749+
26992750
def test_groupby_whitelist(self):
27002751
from string import ascii_lowercase
27012752
letters = np.array(list(ascii_lowercase))

0 commit comments

Comments
 (0)