@@ -116,14 +116,18 @@ def test_groupby_sample_without_n_or_frac():
116
116
tm .assert_series_equal (result , expected )
117
117
118
118
119
- def test_groupby_sample_with_weights ():
119
+ @pytest .mark .parametrize (
120
+ "index, expect_index" ,
121
+ [(["w" , "x" , "y" , "z" ], ["w" , "w" , "y" , "y" ]), ([3 , 4 , 5 , 6 ], [3 , 3 , 5 , 5 ])],
122
+ )
123
+ def test_groupby_sample_with_weights (index , expect_index ):
120
124
values = [1 ] * 2 + [2 ] * 2
121
- df = DataFrame ({"a" : values , "b" : values }, index = Index ([ "w" , "x" , "y" , "z" ] ))
125
+ df = DataFrame ({"a" : values , "b" : values }, index = Index (index ))
122
126
123
127
result = df .groupby ("a" ).sample (n = 2 , replace = True , weights = [1 , 0 , 1 , 0 ])
124
- expected = DataFrame ({"a" : values , "b" : values }, index = Index ([ "w" , "w" , "y" , "y" ] ))
128
+ expected = DataFrame ({"a" : values , "b" : values }, index = Index (expect_index ))
125
129
tm .assert_frame_equal (result , expected )
126
130
127
131
result = df .groupby ("a" )["b" ].sample (n = 2 , replace = True , weights = [1 , 0 , 1 , 0 ])
128
- expected = Series (values , name = "b" , index = Index ([ "w" , "w" , "y" , "y" ] ))
132
+ expected = Series (values , name = "b" , index = Index (expect_index ))
129
133
tm .assert_series_equal (result , expected )
0 commit comments