@@ -2061,30 +2061,59 @@ def test_rank_object_raises(self, ties_method, ascending, na_option,
2061
2061
ascending = ascending ,
2062
2062
na_option = na_option , pct = pct )
2063
2063
2064
+ @pytest .mark .parametrize ("mix_groupings" , [True , False ])
2064
2065
@pytest .mark .parametrize ("as_series" , [True , False ])
2066
+ @pytest .mark .parametrize ("val1,val2" , [
2067
+ ('foo' , 'bar' ), (1 , 2 ), (1. , 2. )])
2065
2068
@pytest .mark .parametrize ("fill_method,limit,exp_vals" , [
2066
2069
("ffill" , None ,
2067
- [np .nan , np .nan , 'foo ' , 'foo ' , 'foo ' , 'bar ' , 'bar ' , 'bar ' ]),
2070
+ [np .nan , np .nan , 'val1 ' , 'val1 ' , 'val1 ' , 'val2 ' , 'val2 ' , 'val2 ' ]),
2068
2071
("ffill" , 1 ,
2069
- [np .nan , np .nan , 'foo ' , 'foo ' , np .nan , 'bar ' , 'bar ' , np .nan ]),
2072
+ [np .nan , np .nan , 'val1 ' , 'val1 ' , np .nan , 'val2 ' , 'val2 ' , np .nan ]),
2070
2073
("bfill" , None ,
2071
- ['foo ' , 'foo ' , 'foo ' , 'bar ' , 'bar ' , 'bar ' , np .nan , np .nan ]),
2074
+ ['val1 ' , 'val1 ' , 'val1 ' , 'val2 ' , 'val2 ' , 'val2 ' , np .nan , np .nan ]),
2072
2075
("bfill" , 1 ,
2073
- [np .nan , 'foo ' , 'foo ' , np .nan , 'bar ' , 'bar ' , np .nan , np .nan ])
2076
+ [np .nan , 'val1 ' , 'val1 ' , np .nan , 'val2 ' , 'val2 ' , np .nan , np .nan ])
2074
2077
])
2075
- def test_group_fill_methods (self , as_series , fill_method , limit , exp_vals ):
2076
- vals = [np .nan , np .nan , 'foo' , np .nan , np .nan , 'bar' , np .nan , np .nan ]
2077
- keys = ['a' ] * len (vals ) + ['b' ] * len (vals )
2078
- df = DataFrame ({'key' : keys , 'val' : vals * 2 })
2079
-
2078
+ def test_group_fill_methods (self , mix_groupings , as_series , val1 , val2 ,
2079
+ fill_method , limit , exp_vals ):
2080
+ vals = [np .nan , np .nan , val1 , np .nan , np .nan , val2 , np .nan , np .nan ]
2081
+ _exp_vals = list (exp_vals )
2082
+ # Overwrite placeholder values
2083
+ for index , exp_val in enumerate (_exp_vals ):
2084
+ if exp_val == 'val1' :
2085
+ _exp_vals [index ] = val1
2086
+ elif exp_val == 'val2' :
2087
+ _exp_vals [index ] = val2
2088
+
2089
+ # Need to modify values and expectations depending on the
2090
+ # Series / DataFrame that we ultimately want to generate
2091
+ if mix_groupings : # ['a', 'b', 'a, 'b', ...]
2092
+ keys = ['a' , 'b' ] * len (vals )
2093
+
2094
+ def interweave (list_obj ):
2095
+ temp = list ()
2096
+ for x in list_obj :
2097
+ temp .extend ([x , x ])
2098
+
2099
+ return temp
2100
+
2101
+ _exp_vals = interweave (_exp_vals )
2102
+ vals = interweave (vals )
2103
+ else : # ['a', 'a', 'a', ... 'b', 'b', 'b']
2104
+ keys = ['a' ] * len (vals ) + ['b' ] * len (vals )
2105
+ _exp_vals = _exp_vals * 2
2106
+ vals = vals * 2
2107
+
2108
+ df = DataFrame ({'key' : keys , 'val' : vals })
2080
2109
if as_series :
2081
2110
result = getattr (
2082
2111
df .groupby ('key' )['val' ], fill_method )(limit = limit )
2083
- exp = Series (exp_vals * 2 , name = 'val' )
2112
+ exp = Series (_exp_vals , name = 'val' )
2084
2113
assert_series_equal (result , exp )
2085
2114
else :
2086
2115
result = getattr (df .groupby ('key' ), fill_method )(limit = limit )
2087
- exp = DataFrame ({'key' : keys , 'val' : exp_vals * 2 })
2116
+ exp = DataFrame ({'key' : keys , 'val' : _exp_vals })
2088
2117
assert_frame_equal (result , exp )
2089
2118
2090
2119
def test_dont_clobber_name_column (self ):
0 commit comments