@@ -2642,9 +2642,37 @@ def raise_if_sum_is_zero(x):
2642
2642
s = pd .Series ([- 1 ,0 ,1 ,2 ])
2643
2643
grouper = s .apply (lambda x : x % 2 )
2644
2644
grouped = s .groupby (grouper )
2645
- self .assertRaises (ValueError ,
2645
+ self .assertRaises (TypeError ,
2646
2646
lambda : grouped .filter (raise_if_sum_is_zero ))
2647
2647
2648
+ def test_filter_bad_shapes (self ):
2649
+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2650
+ s = df ['B' ]
2651
+ g_df = df .groupby ('B' )
2652
+ g_s = s .groupby (s )
2653
+
2654
+ f = lambda x : x
2655
+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2656
+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2657
+
2658
+ f = lambda x : x == 1
2659
+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2660
+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2661
+
2662
+ f = lambda x : np .outer (x , x )
2663
+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2664
+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2665
+
2666
+ def test_filter_nan_is_false (self ):
2667
+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2668
+ s = df ['B' ]
2669
+ g_df = df .groupby (df ['B' ])
2670
+ g_s = s .groupby (s )
2671
+
2672
+ f = lambda x : np .nan
2673
+ assert_frame_equal (g_df .filter (f ), df .loc [[]])
2674
+ assert_series_equal (g_s .filter (f ), s [[]])
2675
+
2648
2676
def test_filter_against_workaround (self ):
2649
2677
np .random .seed (0 )
2650
2678
# Series of ints
@@ -2697,6 +2725,29 @@ def test_filter_against_workaround(self):
2697
2725
new_way = grouped .filter (lambda x : x ['ints' ].mean () > N / 20 )
2698
2726
assert_frame_equal (new_way .sort_index (), old_way .sort_index ())
2699
2727
2728
+ def test_filter_using_len (self ):
2729
+ # BUG GH4447
2730
+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2731
+ grouped = df .groupby ('B' )
2732
+ actual = grouped .filter (lambda x : len (x ) > 2 )
2733
+ expected = DataFrame ({'A' : np .arange (2 , 6 ), 'B' : list ('bbbb' ), 'C' : np .arange (2 , 6 )}, index = np .arange (2 , 6 ))
2734
+ assert_frame_equal (actual , expected )
2735
+
2736
+ actual = grouped .filter (lambda x : len (x ) > 4 )
2737
+ expected = df .ix [[]]
2738
+ assert_frame_equal (actual , expected )
2739
+
2740
+ # Series have always worked properly, but we'll test anyway.
2741
+ s = df ['B' ]
2742
+ grouped = s .groupby (s )
2743
+ actual = grouped .filter (lambda x : len (x ) > 2 )
2744
+ expected = Series (4 * ['b' ], index = np .arange (2 , 6 ))
2745
+ assert_series_equal (actual , expected )
2746
+
2747
+ actual = grouped .filter (lambda x : len (x ) > 4 )
2748
+ expected = s [[]]
2749
+ assert_series_equal (actual , expected )
2750
+
2700
2751
def test_groupby_whitelist (self ):
2701
2752
from string import ascii_lowercase
2702
2753
letters = np .array (list (ascii_lowercase ))
0 commit comments