@@ -2641,9 +2641,37 @@ def raise_if_sum_is_zero(x):
2641
2641
s = pd .Series ([- 1 ,0 ,1 ,2 ])
2642
2642
grouper = s .apply (lambda x : x % 2 )
2643
2643
grouped = s .groupby (grouper )
2644
- self .assertRaises (ValueError ,
2644
+ self .assertRaises (TypeError ,
2645
2645
lambda : grouped .filter (raise_if_sum_is_zero ))
2646
2646
2647
+ def test_filter_bad_shapes (self ):
2648
+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2649
+ s = df ['B' ]
2650
+ g_df = df .groupby ('B' )
2651
+ g_s = s .groupby (s )
2652
+
2653
+ f = lambda x : x
2654
+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2655
+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2656
+
2657
+ f = lambda x : x == 1
2658
+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2659
+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2660
+
2661
+ f = lambda x : np .outer (x , x )
2662
+ self .assertRaises (TypeError , lambda : g_df .filter (f ))
2663
+ self .assertRaises (TypeError , lambda : g_s .filter (f ))
2664
+
2665
+ def test_filter_nan_is_false (self ):
2666
+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2667
+ s = df ['B' ]
2668
+ g_df = df .groupby (df ['B' ])
2669
+ g_s = s .groupby (s )
2670
+
2671
+ f = lambda x : np .nan
2672
+ assert_frame_equal (g_df .filter (f ), df .loc [[]])
2673
+ assert_series_equal (g_s .filter (f ), s [[]])
2674
+
2647
2675
def test_filter_against_workaround (self ):
2648
2676
np .random .seed (0 )
2649
2677
# Series of ints
@@ -2696,6 +2724,29 @@ def test_filter_against_workaround(self):
2696
2724
new_way = grouped .filter (lambda x : x ['ints' ].mean () > N / 20 )
2697
2725
assert_frame_equal (new_way .sort_index (), old_way .sort_index ())
2698
2726
2727
+ def test_filter_using_len (self ):
2728
+ # BUG GH4447
2729
+ df = DataFrame ({'A' : np .arange (8 ), 'B' : list ('aabbbbcc' ), 'C' : np .arange (8 )})
2730
+ grouped = df .groupby ('B' )
2731
+ actual = grouped .filter (lambda x : len (x ) > 2 )
2732
+ expected = DataFrame ({'A' : np .arange (2 , 6 ), 'B' : list ('bbbb' ), 'C' : np .arange (2 , 6 )}, index = np .arange (2 , 6 ))
2733
+ assert_frame_equal (actual , expected )
2734
+
2735
+ actual = grouped .filter (lambda x : len (x ) > 4 )
2736
+ expected = df .ix [[]]
2737
+ assert_frame_equal (actual , expected )
2738
+
2739
+ # Series have always worked properly, but we'll test anyway.
2740
+ s = df ['B' ]
2741
+ grouped = s .groupby (s )
2742
+ actual = grouped .filter (lambda x : len (x ) > 2 )
2743
+ expected = Series (4 * ['b' ], index = np .arange (2 , 6 ))
2744
+ assert_series_equal (actual , expected )
2745
+
2746
+ actual = grouped .filter (lambda x : len (x ) > 4 )
2747
+ expected = s [[]]
2748
+ assert_series_equal (actual , expected )
2749
+
2699
2750
def test_groupby_whitelist (self ):
2700
2751
from string import ascii_lowercase
2701
2752
letters = np .array (list (ascii_lowercase ))
0 commit comments