-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Implement groupby.sample #34069
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
b91b767
d0cf785
0656332
cb5f105
40966bf
4f2e8da
904bdcd
0db1ed7
cbaf4a5
07dacf2
2935645
3e159a8
2397c3a
8c3dfd8
21923a7
11f3d77
e6579d3
cf41a58
37037c2
540af35
611a1b4
1d3c4d2
c2e1615
c0f9ef1
5ffd4ad
1f733d6
a11487d
0369d22
8ceeed1
279cc3c
daf278b
88ef72c
fb55e08
1a3016a
e2d71e3
c136c1f
372da0e
b1bf65f
04789a1
48eea97
b07b377
b447f85
62f7a15
68d8d4a
572cc6c
97034ae
ad0bd61
05a1ba5
e31a119
56a49a0
27cb1ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -2630,6 +2630,93 @@ def _reindex_output( | |||
|
||||
return output.reset_index(drop=True) | ||||
|
||||
def sample(
    self,
    n: Optional[int] = None,
    frac: Optional[float] = None,
    replace: bool = False,
    weights=None,
    random_state=None,
):
    """
    Return a random sample of items from each group.

    You can use `random_state` for reproducibility.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    n : int, optional
        Number of items to return for each group. Cannot be used with
        `frac`. Default = 1 if `frac` is None.
    frac : float, optional
        Fraction of items to return for each group. Cannot be used with `n`.
    replace : bool, default False
        Allow or disallow sampling of the same row more than once.
    weights : list-like, optional
        Default None results in equal probability weighting.
        If passed a list-like then values must have the same length as
        the underlying object and will be used as sampling probabilities
        after normalization within each group. Weights are matched to rows
        by position, not by index label.
    random_state : int, array-like, BitGenerator, np.random.RandomState, optional
        If int, array-like, or BitGenerator (NumPy>=1.17), seed for
        random number generator.
        If np.random.RandomState, use as numpy RandomState object.

    Returns
    -------
    Series or DataFrame
        A new object of same type as caller containing items randomly
        sampled within each group from the caller object.

    See Also
    --------
    DataFrame.sample: Generate random samples from a DataFrame object.
    numpy.random.choice: Generate a random sample from a given 1-D numpy
        array.

    Examples
    --------
    >>> df = pd.DataFrame(
    ...     {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)}
    ... )
    >>> df
           a  b
    0    red  0
    1    red  1
    2   blue  2
    3   blue  3
    4  black  4
    5  black  5
    >>> df.groupby("a").sample(n=1, random_state=1)
           a  b
    4  black  4
    2   blue  2
    0    red  0
    >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2)
    5    5
    3    3
    1    1
    Name: b, dtype: int64
    """
    from pandas.core.reshape.concat import concat

    obj = self._selected_obj

    # ``self.indices`` holds integer *positions* of each group's rows, so
    # every lookup below must be positional (``iloc``). Label-based lookup
    # raises KeyError on a non-integer index and silently selects the wrong
    # rows / weights on a reordered integer index.
    positions = list(self.indices.values())

    if weights is None:
        group_weights = [None] * self.ngroups
    else:
        # Re-index the weights like the sampled object so that each group's
        # slice aligns with that group's rows inside ``sample``.
        weights = Series(weights, index=obj.index)
        group_weights = [weights.iloc[pos] for pos in positions]

    # ``random_state`` is forwarded unchanged to every per-group ``sample``
    # call; its interpretation (seed vs. ready-made RandomState) is whatever
    # that method applies.
    samples = [
        obj.iloc[pos].sample(
            n=n, frac=frac, replace=replace, weights=w, random_state=random_state
        )
        for pos, w in zip(positions, group_weights)
    ]

    return concat(samples, axis=self.axis)
|
||||
|
||||
GroupBy._add_numeric_operations() | ||||
|
||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import pytest | ||
|
||
from pandas import DataFrame, Index, Series | ||
import pandas._testing as tm | ||
|
||
|
||
@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)])
def test_groupby_sample_balanced_groups_shape(n, frac):
    """Two groups of ten rows each yield two rows per group for n=2 or frac=0.2."""
    group_labels = [1] * 10 + [2] * 10
    frame = DataFrame({"a": group_labels, "b": group_labels})
    sampled_labels = [1] * 2 + [2] * 2

    # Which rows get drawn is random, so the expected index is borrowed from
    # the result; only the values and the shape are pinned.
    frame_result = frame.groupby("a").sample(n=n, frac=frac)
    frame_expected = DataFrame(
        {"a": sampled_labels, "b": sampled_labels}, index=frame_result.index
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    series_result = frame.groupby("a")["b"].sample(n=n, frac=frac)
    series_expected = Series(sampled_labels, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)
|
||
|
||
def test_groupby_sample_unbalanced_groups_shape():
    """Groups of unequal size (10 and 20 rows) each contribute exactly n=5 rows."""
    group_labels = [1] * 10 + [2] * 20
    frame = DataFrame({"a": group_labels, "b": group_labels})
    sampled_labels = [1] * 5 + [2] * 5

    # The drawn rows are random; reuse the result's index so only the values
    # and the per-group counts are compared.
    frame_result = frame.groupby("a").sample(n=5)
    frame_expected = DataFrame(
        {"a": sampled_labels, "b": sampled_labels}, index=frame_result.index
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    series_result = frame.groupby("a")["b"].sample(n=5)
    series_expected = Series(sampled_labels, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)
|
||
|
||
def test_groupby_sample_n_and_frac_raises():
    """Passing both `n` and `frac` is rejected for frame and series groupbys."""
    frame = DataFrame({"a": [1, 2], "b": [1, 2]})
    msg = "Please enter a value for `frac` OR `n`, not both"

    frame_groupby = frame.groupby("a")
    with pytest.raises(ValueError, match=msg):
        frame_groupby.sample(n=1, frac=1.0)

    series_groupby = frame.groupby("a")["b"]
    with pytest.raises(ValueError, match=msg):
        series_groupby.sample(n=1, frac=1.0)
|
||
|
||
def test_groupby_sample_frac_gt_one_without_replacement_raises():
    """Upsampling (`frac` > 1) without replacement is rejected."""
    frame = DataFrame({"a": [1, 2], "b": [1, 2]})
    msg = "Replace has to be set to `True` when upsampling the population `frac` > 1."

    frame_groupby = frame.groupby("a")
    with pytest.raises(ValueError, match=msg):
        frame_groupby.sample(frac=1.5, replace=False)

    series_groupby = frame.groupby("a")["b"]
    with pytest.raises(ValueError, match=msg):
        series_groupby.sample(frac=1.5, replace=False)
|
||
|
||
@pytest.mark.parametrize("n", [-1, 1.5])
def test_groupby_sample_invalid_n_raises(n):
    """A negative or non-integer `n` raises ValueError with the matching message."""
    frame = DataFrame({"a": [1, 2], "b": [1, 2]})

    # The two parametrized values exercise two different validation errors.
    msg = (
        "Please provide positive value"
        if n < 0
        else "Only integers accepted as `n` values"
    )

    frame_groupby = frame.groupby("a")
    with pytest.raises(ValueError, match=msg):
        frame_groupby.sample(n=n)

    series_groupby = frame.groupby("a")["b"]
    with pytest.raises(ValueError, match=msg):
        series_groupby.sample(n=n)
|
||
|
||
def test_groupby_sample_oversample():
    """With replacement, frac=2.0 doubles each ten-row group to twenty rows."""
    group_labels = [1] * 10 + [2] * 10
    frame = DataFrame({"a": group_labels, "b": group_labels})
    oversampled_labels = [1] * 20 + [2] * 20

    # Drawn rows are random; the expected index is taken from the result so
    # that only values and group sizes are checked.
    frame_result = frame.groupby("a").sample(frac=2.0, replace=True)
    frame_expected = DataFrame(
        {"a": oversampled_labels, "b": oversampled_labels},
        index=frame_result.index,
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    series_result = frame.groupby("a")["b"].sample(frac=2.0, replace=True)
    series_expected = Series(
        oversampled_labels, name="b", index=series_result.index
    )
    tm.assert_series_equal(series_result, series_expected)
|
||
|
||
def test_groupby_sample_without_n_or_frac():
    """With neither `n` nor `frac` given, one row is drawn from each group."""
    group_labels = [1] * 10 + [2] * 10
    frame = DataFrame({"a": group_labels, "b": group_labels})
    one_per_group = [1, 2]

    # The expected index is copied from the result because the drawn rows
    # are random.
    frame_result = frame.groupby("a").sample(n=None, frac=None)
    frame_expected = DataFrame(
        {"a": one_per_group, "b": one_per_group}, index=frame_result.index
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    series_result = frame.groupby("a")["b"].sample(n=None, frac=None)
    series_expected = Series(one_per_group, name="b", index=series_result.index)
    tm.assert_series_equal(series_result, series_expected)
|
||
|
||
def test_groupby_sample_with_weights():
    """Zero-weighted rows are never drawn, making the sampled index deterministic."""
    group_labels = [1] * 2 + [2] * 2
    frame = DataFrame(
        {"a": group_labels, "b": group_labels}, index=Index([0, 1, 2, 3])
    )

    # Within each group only the first row carries weight, so drawing two
    # rows with replacement must return that row twice.
    frame_result = frame.groupby("a").sample(
        n=2, replace=True, weights=[1, 0, 1, 0]
    )
    frame_expected = DataFrame(
        {"a": group_labels, "b": group_labels}, index=Index([0, 0, 2, 2])
    )
    tm.assert_frame_equal(frame_result, frame_expected)

    series_result = frame.groupby("a")["b"].sample(
        n=2, replace=True, weights=[1, 0, 1, 0]
    )
    series_expected = Series(group_labels, name="b", index=Index([0, 0, 2, 2]))
    tm.assert_series_equal(series_result, series_expected)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need the full path to these classes in the docs.