diff --git a/doc/source/reference/groupby.rst b/doc/source/reference/groupby.rst index 5f6bef2579d27..76cb53559f334 100644 --- a/doc/source/reference/groupby.rst +++ b/doc/source/reference/groupby.rst @@ -116,6 +116,7 @@ application to columns of a specific data type. DataFrameGroupBy.quantile DataFrameGroupBy.rank DataFrameGroupBy.resample + DataFrameGroupBy.sample DataFrameGroupBy.shift DataFrameGroupBy.size DataFrameGroupBy.skew diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 92f7c0f6b59a3..0ea6d22c64b4c 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -275,6 +275,7 @@ Other enhancements such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`) - :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`) - :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`) +- :class:`~pandas.core.groupby.generic.DataFrameGroupBy` and :class:`~pandas.core.groupby.generic.SeriesGroupBy` now implement the ``sample`` method for doing random sampling within groups (:issue:`31775`) - :meth:`DataFrame.to_numpy` now supports the ``na_value`` keyword to control the NA sentinel in the output array (:issue:`33820`) - The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals` method, similarly to :meth:`Series.equals` (:issue:`27081`). diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 6183638ab587e..28009f7f8937a 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4868,6 +4868,10 @@ def sample( See Also -------- + DataFrameGroupBy.sample: Generates random samples from each group of a + DataFrame object. + SeriesGroupBy.sample: Generates random samples from each group of a + Series object. numpy.random.choice: Generates a random sample from a given 1-D numpy array. diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index 363286704ba95..08352d737dee0 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -180,6 +180,7 @@ def _gotitem(self, key, ndim, subset=None): "tail", "take", "transform", + "sample", ] ) # Valid values of `name` for `groupby.transform(name)` diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e385a78142ba5..c2be8d96402df 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -23,6 +23,7 @@ class providing the base-class of operations. List, Mapping, Optional, + Sequence, Tuple, Type, TypeVar, @@ -2695,6 +2696,118 @@ def _reindex_output( return output.reset_index(drop=True) + def sample( + self, + n: Optional[int] = None, + frac: Optional[float] = None, + replace: bool = False, + weights: Optional[Union[Sequence, Series]] = None, + random_state=None, + ): + """ + Return a random sample of items from each group. + + You can use `random_state` for reproducibility. + + .. versionadded:: 1.1.0 + + Parameters + ---------- + n : int, optional + Number of items to return for each group. Cannot be used with + `frac` and must be no larger than the smallest group unless + `replace` is True. Default is one if `frac` is None. + frac : float, optional + Fraction of items to return. Cannot be used with `n`. + replace : bool, default False + Allow or disallow sampling of the same row more than once. + weights : list-like, optional + Default None results in equal probability weighting. + If passed a list-like then values must have the same length as + the underlying DataFrame or Series object and will be used as + sampling probabilities after normalization within each group. + Values must be non-negative with at least one positive element + within each group. + random_state : int, array-like, BitGenerator, np.random.RandomState, optional + If int, array-like, or BitGenerator (NumPy>=1.17), seed for + random number generator + If np.random.RandomState, use as numpy RandomState object. + + Returns + ------- + Series or DataFrame + A new object of same type as caller containing items randomly + sampled within each group from the caller object. + + See Also + -------- + DataFrame.sample: Generate random samples from a DataFrame object. + numpy.random.choice: Generate a random sample from a given 1-D numpy + array. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)} + ... ) + >>> df + a b + 0 red 0 + 1 red 1 + 2 blue 2 + 3 blue 3 + 4 black 4 + 5 black 5 + + Select one row at random for each distinct value in column a. The + `random_state` argument can be used to guarantee reproducibility: + + >>> df.groupby("a").sample(n=1, random_state=1) + a b + 4 black 4 + 2 blue 2 + 1 red 1 + + Set `frac` to sample fixed proportions rather than counts: + + >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2) + 5 5 + 2 2 + 0 0 + Name: b, dtype: int64 + + Control sample probabilities within groups by setting weights: + + >>> df.groupby("a").sample( + ... n=1, + ... weights=[1, 1, 1, 0, 0, 1], + ... random_state=1, + ... ) + a b + 5 black 5 + 2 blue 2 + 0 red 0 + """ + from pandas.core.reshape.concat import concat + + if weights is not None: + weights = Series(weights, index=self._selected_obj.index) + ws = [weights[idx] for idx in self.indices.values()] + else: + ws = [None] * self.ngroups + + if random_state is not None: + random_state = com.random_state(random_state) + + samples = [ + obj.sample( + n=n, frac=frac, replace=replace, weights=w, random_state=random_state + ) + for (_, obj), w in zip(self, ws) + ] + + return concat(samples, axis=self.axis) + @doc(GroupBy) def get_groupby( diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py new file mode 100644 index 0000000000000..412e3e8f732de --- /dev/null +++ b/pandas/tests/groupby/test_sample.py @@ -0,0 +1,125 @@ +import pytest + +from pandas import DataFrame, Index, Series +import pandas._testing as tm + + +@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)]) +def test_groupby_sample_balanced_groups_shape(n, frac): + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(n=n, frac=frac) + values = [1] * 2 + [2] * 2 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=n, frac=frac) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_unbalanced_groups_shape(): + values = [1] * 10 + [2] * 20 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(n=5) + values = [1] * 5 + [2] * 5 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=5) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_index_value_spans_groups(): + values = [1] * 3 + [2] * 3 + df = DataFrame({"a": values, "b": values}, index=[1, 2, 2, 2, 2, 2]) + + result = df.groupby("a").sample(n=2) + values = [1] * 2 + [2] * 2 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=2) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_n_and_frac_raises(): + df = DataFrame({"a": [1, 2], "b": [1, 2]}) + msg = "Please enter a value for `frac` OR `n`, not both" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(n=1, frac=1.0) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(n=1, frac=1.0) + + +def test_groupby_sample_frac_gt_one_without_replacement_raises(): + df = DataFrame({"a": [1, 2], "b": [1, 2]}) + msg = "Replace has to be set to `True` when upsampling the population `frac` > 1." + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(frac=1.5, replace=False) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(frac=1.5, replace=False) + + +@pytest.mark.parametrize("n", [-1, 1.5]) +def test_groupby_sample_invalid_n_raises(n): + df = DataFrame({"a": [1, 2], "b": [1, 2]}) + + if n < 0: + msg = "Please provide positive value" + else: + msg = "Only integers accepted as `n` values" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(n=n) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(n=n) + + +def test_groupby_sample_oversample(): + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(frac=2.0, replace=True) + values = [1] * 20 + [2] * 20 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(frac=2.0, replace=True) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_without_n_or_frac(): + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) + + result = df.groupby("a").sample(n=None, frac=None) + expected = DataFrame({"a": [1, 2], "b": [1, 2]}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=None, frac=None) + expected = Series([1, 2], name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_with_weights(): + values = [1] * 2 + [2] * 2 + df = DataFrame({"a": values, "b": values}, index=Index(["w", "x", "y", "z"])) + + result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0]) + expected = DataFrame({"a": values, "b": values}, index=Index(["w", "w", "y", "y"])) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) + expected = Series(values, name="b", index=Index(["w", "w", "y", "y"])) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/test_whitelist.py b/pandas/tests/groupby/test_whitelist.py index 6b33049a664de..1598cc24ba6fb 100644 --- a/pandas/tests/groupby/test_whitelist.py +++ b/pandas/tests/groupby/test_whitelist.py @@ -328,6 +328,7 @@ def test_tab_completion(mframe): "rolling", "expanding", "pipe", + "sample", } assert results == expected