Skip to content

Commit 1dbed44

Browse files
committed
BUG: Cannot sample on DataFrameGroupBy with weights when index is specified
1 parent a7402c1 commit 1dbed44

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

pandas/core/groupby/groupby.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -3076,7 +3076,7 @@ def sample(
30763076

30773077
if weights is not None:
30783078
weights = Series(weights, index=self._selected_obj.index)
3079-
ws = [weights[idx] for idx in self.indices.values()]
3079+
ws = [weights.iloc[idx] for idx in self.indices.values()]
30803080
else:
30813081
ws = [None] * self.ngroups
30823082

pandas/tests/groupby/test_sample.py

+8 -4
Original file line number | Diff line number | Diff line change
@@ -116,14 +116,18 @@ def test_groupby_sample_without_n_or_frac():
116116
tm.assert_series_equal(result, expected)
117117

118118

119-
def test_groupby_sample_with_weights():
119+
@pytest.mark.parametrize(
120+
"index, expect_index",
121+
[(["w", "x", "y", "z"], ["w", "w", "y", "y"]), ([3, 4, 5, 6], [3, 3, 5, 5])]
122+
)
123+
def test_groupby_sample_with_weights(index, expect_index):
120124
values = [1] * 2 + [2] * 2
121-
df = DataFrame({"a": values, "b": values}, index=Index(["w", "x", "y", "z"]))
125+
df = DataFrame({"a": values, "b": values}, index=Index(index))
122126

123127
result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0])
124-
expected = DataFrame({"a": values, "b": values}, index=Index(["w", "w", "y", "y"]))
128+
expected = DataFrame({"a": values, "b": values}, index=Index(expect_index))
125129
tm.assert_frame_equal(result, expected)
126130

127131
result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0])
128-
expected = Series(values, name="b", index=Index(["w", "w", "y", "y"]))
132+
expected = Series(values, name="b", index=Index(expect_index))
129133
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)