@@ -116,14 +116,21 @@ def test_groupby_sample_without_n_or_frac():
116
116
tm .assert_series_equal (result , expected )
117
117
118
118
119
- def test_groupby_sample_with_weights ():
119
+ @pytest .mark .parametrize (
120
+ "index, expect_index" ,
121
+ [
122
+ (["w" , "x" , "y" , "z" ], ["w" , "w" , "y" , "y" ]),
123
+ ([3 , 4 , 5 , 6 ], [3 , 3 , 5 , 5 ])
124
+ ]
125
+ )
126
+ def test_groupby_sample_with_weights (index , expect_index ):
120
127
values = [1 ] * 2 + [2 ] * 2
121
- df = DataFrame ({"a" : values , "b" : values }, index = Index ([ "w" , "x" , "y" , "z" ] ))
128
+ df = DataFrame ({"a" : values , "b" : values }, index = Index (index ))
122
129
123
130
result = df .groupby ("a" ).sample (n = 2 , replace = True , weights = [1 , 0 , 1 , 0 ])
124
- expected = DataFrame ({"a" : values , "b" : values }, index = Index ([ "w" , "w" , "y" , "y" ] ))
131
+ expected = DataFrame ({"a" : values , "b" : values }, index = Index (expect_index ))
125
132
tm .assert_frame_equal (result , expected )
126
133
127
134
result = df .groupby ("a" )["b" ].sample (n = 2 , replace = True , weights = [1 , 0 , 1 , 0 ])
128
- expected = Series (values , name = "b" , index = Index ([ "w" , "w" , "y" , "y" ] ))
135
+ expected = Series (values , name = "b" , index = Index (expect_index ))
129
136
tm .assert_series_equal (result , expected )
0 commit comments