ENH: Implement groupby.sample #34069

Merged
merged 51 commits into from
Jun 14, 2020
Changes from 40 commits
Commits
b91b767
ENH: Implement groupby.sample
dsaxton May 8, 2020
d0cf785
Add test
dsaxton May 8, 2020
0656332
Add tag
dsaxton May 8, 2020
cb5f105
Troubleshoot CI
dsaxton May 8, 2020
40966bf
doc nit
dsaxton May 8, 2020
4f2e8da
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 8, 2020
904bdcd
Move tag
dsaxton May 8, 2020
0db1ed7
Dispatch and allow weights
dsaxton May 8, 2020
cbaf4a5
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 8, 2020
07dacf2
black
dsaxton May 8, 2020
2935645
Add doc examples
dsaxton May 8, 2020
3e159a8
Fixup
dsaxton May 8, 2020
2397c3a
Another fixup
dsaxton May 8, 2020
8c3dfd8
Edit tests
dsaxton May 8, 2020
21923a7
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 9, 2020
11f3d77
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 10, 2020
e6579d3
Update docstring
dsaxton May 11, 2020
cf41a58
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 11, 2020
37037c2
Don't use selected_obj.index
dsaxton May 12, 2020
540af35
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 12, 2020
611a1b4
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 14, 2020
1d3c4d2
Iterate over self
dsaxton May 14, 2020
c2e1615
Sequence
dsaxton May 21, 2020
c0f9ef1
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 21, 2020
5ffd4ad
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 25, 2020
1f733d6
Fix examples
dsaxton May 25, 2020
a11487d
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton May 30, 2020
0369d22
Add random_state tests
dsaxton Jun 1, 2020
8ceeed1
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton Jun 1, 2020
279cc3c
Copy less
dsaxton Jun 1, 2020
daf278b
random -> arange
dsaxton Jun 1, 2020
88ef72c
Skip for numpy version
dsaxton Jun 1, 2020
fb55e08
Revert "Skip for numpy version"
dsaxton Jun 1, 2020
1a3016a
Try again
dsaxton Jun 1, 2020
e2d71e3
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton Jun 1, 2020
c136c1f
Fix
dsaxton Jun 2, 2020
372da0e
Delete
dsaxton Jun 2, 2020
b1bf65f
random_state
dsaxton Jun 2, 2020
04789a1
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton Jun 2, 2020
48eea97
Doc
dsaxton Jun 2, 2020
b07b377
not None
dsaxton Jun 4, 2020
b447f85
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton Jun 4, 2020
62f7a15
doc
dsaxton Jun 4, 2020
68d8d4a
doc
dsaxton Jun 5, 2020
572cc6c
Merge remote-tracking branch 'upstream/master' into groupby-sample
dsaxton Jun 5, 2020
97034ae
Add weights example
dsaxton Jun 5, 2020
ad0bd61
Fix weights index and adjust test
dsaxton Jun 5, 2020
05a1ba5
Update docstring
dsaxton Jun 5, 2020
e31a119
Merge remote-tracking branch 'upstream/master' into groupby-sample
Jun 8, 2020
56a49a0
Update doc
Jun 10, 2020
27cb1ba
Merge remote-tracking branch 'upstream/master' into groupby-sample
Jun 10, 2020
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
@@ -275,6 +275,7 @@ Other enhancements
such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`)
- :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`)
- :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
- :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` now implement the ``sample`` method for doing random sampling within groups (:issue:`31775`)
Contributor

Need the full path to these classes in the docs.

- :meth:`DataFrame.to_numpy` now supports the ``na_value`` keyword to control the NA sentinel in the output array (:issue:`33820`)
- The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals`
method, similarly to :meth:`Series.equals` (:issue:`27081`).
4 changes: 4 additions & 0 deletions pandas/core/generic.py
@@ -4859,6 +4859,10 @@ def sample(

        See Also
        --------
        DataFrameGroupBy.sample: Generates random samples from each group of a
            DataFrame object.
        SeriesGroupBy.sample: Generates random samples from each group of a
            Series object.
        numpy.random.choice: Generates a random sample from a given 1-D numpy
            array.

1 change: 1 addition & 0 deletions pandas/core/groupby/base.py
@@ -180,6 +180,7 @@ def _gotitem(self, key, ndim, subset=None):
        "tail",
        "take",
        "transform",
        "sample",
    ]
)
# Valid values of `name` for `groupby.transform(name)`
98 changes: 98 additions & 0 deletions pandas/core/groupby/groupby.py
@@ -23,6 +23,7 @@ class providing the base-class of operations.
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
@@ -2696,6 +2697,103 @@ def _reindex_output(

        return output.reset_index(drop=True)

    def sample(
        self,
        n: Optional[int] = None,
        frac: Optional[float] = None,
        replace: bool = False,
        weights: Optional[Union[Sequence, Series]] = None,
        random_state=None,
    ):
        """
        Return a random sample of items from each group.

        You can use `random_state` for reproducibility.

        .. versionadded:: 1.1.0

        Parameters
        ----------
        n : int, optional
            Number of items to return. Cannot be used with `frac`.
            Default = 1 if `frac` is None.
        frac : float, optional
            Fraction of items to return. Cannot be used with `n`.
        replace : bool, default False
            Allow or disallow sampling of the same row more than once.
        weights : list-like, optional
            Default None results in equal probability weighting.
            If passed a list-like then values must have the same length as
            the underlying object and will be used as sampling probabilities
            after normalization within each group.
        random_state : int, array-like, BitGenerator, np.random.RandomState, optional
            If int, array-like, or BitGenerator (NumPy>=1.17), seed for
Contributor

If it is a BitGenerator, do you use a Generator to produce the random samples, or a RandomState? Best practice is to use a Generator, since RandomState is effectively frozen in time. If it is an int, is it used as a seed for np.random.default_rng(), or for RandomState, when NumPy >= 1.17?

Member Author

This follows the pattern used in pandas.core.generic.sample, which processes random_state through pandas.core.common.random_state:

def random_state(state=None):
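
To make the exchange above concrete, here is a minimal sketch of a seed-normalizing helper in the same spirit. It is not the pandas implementation: the name `normalize_random_state` and its exact branches are assumptions for illustration. It returns a legacy `np.random.RandomState` in every case, which is the behaviour the reviewer is asking about, rather than a `np.random.Generator`.

```python
import numpy as np


def normalize_random_state(state=None):
    """Hypothetical helper: turn a seed-like value into a RandomState."""
    if state is None:
        # No seed given: build a freshly seeded RandomState.
        return np.random.RandomState()
    if isinstance(state, np.random.RandomState):
        # Already a RandomState: use it as-is so its stream keeps advancing.
        return state
    if isinstance(state, (int, np.integer, np.random.BitGenerator)):
        # Ints and BitGenerators are both accepted as RandomState seeds.
        return np.random.RandomState(state)
    raise TypeError("random_state must be an int, BitGenerator, or RandomState")


rs_from_int = normalize_random_state(42)
rs_from_bitgen = normalize_random_state(np.random.MT19937(0))
```

A Generator-based variant would call `np.random.default_rng(state)` instead; whether to switch is the design question raised in the review, and this sketch does not settle it.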

            random number generator
            If np.random.RandomState, use as numpy RandomState object.

        Returns
        -------
        Series or DataFrame
            A new object of same type as caller containing items randomly
            sampled within each group from the caller object.

        See Also
        --------
        DataFrame.sample: Generate random samples from a DataFrame object.
        numpy.random.choice: Generate a random sample from a given 1-D numpy
            array.

        Examples
        --------
        >>> df = pd.DataFrame(
        ...     {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)}
        ... )
        >>> df
               a  b
        0    red  0
        1    red  1
        2   blue  2
        3   blue  3
        4  black  4
        5  black  5

        Select one row at random for each distinct value in column a. The
        `random_state` argument can be used to guarantee reproducibility:

        >>> df.groupby("a").sample(n=1, random_state=1)
               a  b
        4  black  4
        2   blue  2
        1    red  1

        Set `frac` to sample fixed proportions rather than counts:

        >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2)
        5    5
        2    2
        0    0
        Name: b, dtype: int64
        """
        from pandas.core.reshape.concat import concat

        if weights is not None:
            weights = Series(weights)
            ws = [weights[idx] for idx in self.indices.values()]
        else:
            ws = [None] * self.ngroups

        if random_state:
Contributor

I don't think this is enough; you always need a random_state here that is consistent across the entire groupby.

Contributor

I think either is fine. Either we get a random state from NumPy's global random state initially and re-use it, or we have each group draw from the global random state pool. It's similar to these two calls

  1. .sample(random_state=0) # each call uses the seed 0
  2. .sample(random_state=np.random.RandomState(0)) # each call makes an independent draw
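
The difference between the two calls listed above can be checked with a small sketch against `DataFrame.sample`. The frame and variable names here are illustrative assumptions, not part of the PR.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"b": range(100)})

# An int seed is re-applied on every call, so repeated calls pick the same rows.
first = df.sample(n=5, random_state=0)
second = df.sample(n=5, random_state=0)
same_with_int_seed = first.index.equals(second.index)

# A shared RandomState advances after each call, so the second call
# continues the stream instead of restarting it.
rs = np.random.RandomState(0)
third = df.sample(n=5, random_state=rs)
fourth = df.sample(n=5, random_state=rs)
same_with_shared_state = third.index.equals(fourth.index)
```

Applied to a groupby, the first style would reseed identically for every group, while the second lets each group draw from one continuing stream.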

Member Author

I actually meant to write this as random_state is not None (I didn't consider other "falsey" values).
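
A minimal sketch of why the truthiness check matters: the seed `0` is a valid seed but is falsy in Python. The two helper functions below are hypothetical and exist only to contrast the two conditions.

```python
import numpy as np


def resolve_with_truthiness(random_state=None):
    # Mirrors `if random_state:`; a seed of 0 falls through to None.
    if random_state:
        return np.random.RandomState(random_state)
    return None


def resolve_with_none_check(random_state=None):
    # Mirrors `if random_state is not None:`; a seed of 0 is honoured.
    if random_state is not None:
        return np.random.RandomState(random_state)
    return None


skipped = resolve_with_truthiness(0)
honoured = resolve_with_none_check(0)
```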

            random_state = com.random_state(random_state)

        samples = [
            obj.sample(
                n=n, frac=frac, replace=replace, weights=w, random_state=random_state
            )
            for (_, obj), w in zip(self, ws)
        ]

        return concat(samples, axis=self.axis)
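
The approach in this hunk (sample each group, then concatenate) can be sketched outside the GroupBy class using only public pandas calls. The function name `sample_within_groups` is an assumption for illustration, and the sketch omits the `weights` handling shown in the diff.

```python
import numpy as np
import pandas as pd


def sample_within_groups(df, by, n=None, frac=None, replace=False, random_state=None):
    """Hypothetical standalone version of per-group sampling."""
    # Build one RandomState up front so every group draws from the same stream.
    if random_state is not None and not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)

    pieces = [
        group.sample(n=n, frac=frac, replace=replace, random_state=random_state)
        for _, group in df.groupby(by)
    ]
    return pd.concat(pieces)


df = pd.DataFrame({"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)})
result = sample_within_groups(df, "a", n=1, random_state=1)
```

Converting the seed once, before the loop, is what keeps the draw consistent across the whole groupby, which is the concern raised in the review thread above.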


@doc(GroupBy)
def get_groupby(
125 changes: 125 additions & 0 deletions pandas/tests/groupby/test_sample.py
@@ -0,0 +1,125 @@
import pytest

from pandas import DataFrame, Index, Series
import pandas._testing as tm


@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)])
def test_groupby_sample_balanced_groups_shape(n, frac):
    values = [1] * 10 + [2] * 10
    df = DataFrame({"a": values, "b": values})

    result = df.groupby("a").sample(n=n, frac=frac)
    values = [1] * 2 + [2] * 2
    expected = DataFrame({"a": values, "b": values}, index=result.index)
    tm.assert_frame_equal(result, expected)

    result = df.groupby("a")["b"].sample(n=n, frac=frac)
    expected = Series(values, name="b", index=result.index)
    tm.assert_series_equal(result, expected)


def test_groupby_sample_unbalanced_groups_shape():
    values = [1] * 10 + [2] * 20
    df = DataFrame({"a": values, "b": values})

    result = df.groupby("a").sample(n=5)
    values = [1] * 5 + [2] * 5
    expected = DataFrame({"a": values, "b": values}, index=result.index)
    tm.assert_frame_equal(result, expected)

    result = df.groupby("a")["b"].sample(n=5)
    expected = Series(values, name="b", index=result.index)
    tm.assert_series_equal(result, expected)


def test_groupby_sample_index_value_spans_groups():
    values = [1] * 3 + [2] * 3
    df = DataFrame({"a": values, "b": values}, index=[1, 2, 2, 2, 2, 2])

    result = df.groupby("a").sample(n=2)
    values = [1] * 2 + [2] * 2
    expected = DataFrame({"a": values, "b": values}, index=result.index)
    tm.assert_frame_equal(result, expected)

    result = df.groupby("a")["b"].sample(n=2)
    expected = Series(values, name="b", index=result.index)
    tm.assert_series_equal(result, expected)


def test_groupby_sample_n_and_frac_raises():
    df = DataFrame({"a": [1, 2], "b": [1, 2]})
    msg = "Please enter a value for `frac` OR `n`, not both"

    with pytest.raises(ValueError, match=msg):
        df.groupby("a").sample(n=1, frac=1.0)

    with pytest.raises(ValueError, match=msg):
        df.groupby("a")["b"].sample(n=1, frac=1.0)


def test_groupby_sample_frac_gt_one_without_replacement_raises():
    df = DataFrame({"a": [1, 2], "b": [1, 2]})
    msg = "Replace has to be set to `True` when upsampling the population `frac` > 1."

    with pytest.raises(ValueError, match=msg):
        df.groupby("a").sample(frac=1.5, replace=False)

    with pytest.raises(ValueError, match=msg):
        df.groupby("a")["b"].sample(frac=1.5, replace=False)


@pytest.mark.parametrize("n", [-1, 1.5])
def test_groupby_sample_invalid_n_raises(n):
    df = DataFrame({"a": [1, 2], "b": [1, 2]})

    if n < 0:
        msg = "Please provide positive value"
    else:
        msg = "Only integers accepted as `n` values"

    with pytest.raises(ValueError, match=msg):
        df.groupby("a").sample(n=n)

    with pytest.raises(ValueError, match=msg):
        df.groupby("a")["b"].sample(n=n)


def test_groupby_sample_oversample():
    values = [1] * 10 + [2] * 10
    df = DataFrame({"a": values, "b": values})

    result = df.groupby("a").sample(frac=2.0, replace=True)
    values = [1] * 20 + [2] * 20
    expected = DataFrame({"a": values, "b": values}, index=result.index)
    tm.assert_frame_equal(result, expected)

    result = df.groupby("a")["b"].sample(frac=2.0, replace=True)
    expected = Series(values, name="b", index=result.index)
    tm.assert_series_equal(result, expected)


def test_groupby_sample_without_n_or_frac():
    values = [1] * 10 + [2] * 10
    df = DataFrame({"a": values, "b": values})

    result = df.groupby("a").sample(n=None, frac=None)
    expected = DataFrame({"a": [1, 2], "b": [1, 2]}, index=result.index)
    tm.assert_frame_equal(result, expected)

    result = df.groupby("a")["b"].sample(n=None, frac=None)
    expected = Series([1, 2], name="b", index=result.index)
    tm.assert_series_equal(result, expected)


def test_groupby_sample_with_weights():
    values = [1] * 2 + [2] * 2
    df = DataFrame({"a": values, "b": values}, index=Index([0, 1, 2, 3]))

    result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0])
    expected = DataFrame({"a": values, "b": values}, index=Index([0, 0, 2, 2]))
    tm.assert_frame_equal(result, expected)

    result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0])
    expected = Series(values, name="b", index=Index([0, 0, 2, 2]))
    tm.assert_series_equal(result, expected)
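
As a usage-level illustration of the weights test above, the following sketch assumes pandas 1.1.0 or later, where `groupby(...).sample` is available. The column values are made up for the example. Because the second row of each group has weight 0, only the first row of each group can ever be drawn, so the result does not depend on the random state.

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2, 2], "b": [10, 11, 20, 21]})

# One weight per row of the original frame; weights are normalized within
# each group, so each group here has probabilities [1.0, 0.0].
weights = [1, 0, 1, 0]

result = df.groupby("a").sample(n=2, replace=True, weights=weights)
picked_index = list(result.index)
picked_values = result["b"].tolist()
```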
1 change: 1 addition & 0 deletions pandas/tests/groupby/test_whitelist.py
@@ -328,6 +328,7 @@ def test_tab_completion(mframe):
        "rolling",
        "expanding",
        "pipe",
        "sample",
    }
    assert results == expected
