From b91b767c39dabaf6114cac2b5df4ed887f0c6d60 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Thu, 7 May 2020 22:20:31 -0500 Subject: [PATCH 01/34] ENH: Implement groupby.sample --- doc/source/whatsnew/v1.1.0.rst | 2 +- pandas/core/groupby/groupby.py | 63 ++++++++++++++++++++++++ pandas/tests/groupby/test_sample.py | 75 +++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 pandas/tests/groupby/test_sample.py diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 9c424f70b1ee0..20079c247d6e3 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -150,7 +150,7 @@ Other enhancements such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`) - :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`) - :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`) -- +- :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` now implement the ``sample`` method for doing random sampling within groups (:issue:`31775`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 81c3fd7ad9e89..2d97f0f15d384 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2630,6 +2630,69 @@ def _reindex_output( return output.reset_index(drop=True) + def sample( + self, + n: Optional[int] = None, + frac: Optional[float] = None, + replace: bool = False, + random_state=None, + ): + """ + Return a random sample of items from each group. + + You can use `random_state` for reproducibility. + + Parameters + ---------- + n : int, optional + Number of items to return. Cannot be used with `frac`. + Default = 1 if `frac` = None. + frac : float, optional + Fraction of items to return. Cannot be used with `n`. + replace : bool, default False + Allow or disallow sampling of the same row more than once. + random_state : int, array-like, BitGenerator, np.random.RandomState, optional + If int, array-like, or BitGenerator (NumPy>=1.17), seed for + random number generator + If np.random.RandomState, use as numpy RandomState object. + + Returns + ------- + Series or DataFrame + A new object of same type as caller containing items randomly + sampled within each group from the caller object. + + See Also + -------- + DataFrame.sample: Generate random samples from a DataFrame object. + numpy.random.choice: Generates a random sample from a given 1-D numpy + array. + """ + if frac is not None and frac > 1 and not replace: + raise ValueError("replace must be set to True when frac > 1") + if n is not None and (n != int(n) or n < 0): + raise ValueError("Only non-negative integers accepted as n values") + + if n is None and frac is None: + ns = [1] * self.ngroups + elif n is None and frac is not None: + ns = [int(frac * len(i)) for i in self.indices.values()] + elif n is not None and frac is None: + ns = [n] * self.ngroups + else: + raise ValueError("Please enter a value for frac or n but not both") + + rs = com.random_state(random_state) + + idx_list = [ + rs.choice(i, m, replace=replace) for i, m in zip(self.indices.values(), ns) + ] + + cons = self._selected_obj.index._constructor + idx = cons(np.concatenate(idx_list)) + + return self._selected_obj.loc[idx] + GroupBy._add_numeric_operations() diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py new file mode 100644 index 0000000000000..bf7033fda7cc4 --- /dev/null +++ b/pandas/tests/groupby/test_sample.py @@ -0,0 +1,75 @@ +import pytest + +from pandas import DataFrame, Series +import pandas._testing as tm + + +@pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)]) +def test_groupby_sample_balanced_groups_shape(n, frac): + df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + + result = df.groupby("a").sample(n=n, frac=frac) + expected = DataFrame({"a": [1] * 2 + [2] * 2, "b": [1] * 4}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=n, frac=frac) + expected = Series([1] * 4, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_unbalanced_groups_shape(): + df = DataFrame({"a": [1] * 10 + [2] * 20, "b": [1] * 30}) + + result = df.groupby("a").sample(n=5) + expected = DataFrame({"a": [1] * 5 + [2] * 5, "b": [1] * 10}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=5) + expected = Series([1] * 10, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_n_and_frac_raises(): + df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + msg = "Please enter a value for frac or n but not both" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(n=1, frac=1.0) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(n=1, frac=1.0) + + +def test_groupby_sample_frac_gt_one_without_replacement_raises(): + df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + msg = "replace must be set to True when frac > 1" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(frac=1.5, replace=False) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(frac=1.5, replace=False) + + +def test_groupby_sample_oversample(): + df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + + result = df.groupby("a").sample(frac=2.0, replace=True) + expected = DataFrame({"a": [1] * 20 + [2] * 20, "b": [1] * 40}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(frac=2.0, replace=True) + expected = Series([1] * 40, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + +def test_groupby_sample_without_n_or_frac(): + df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + + result = df.groupby("a").sample(n=None, frac=None) + expected = DataFrame({"a": [1, 2], "b": [1, 1]}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=None, frac=None) + expected = Series([1, 1], name="b", index=result.index) + tm.assert_series_equal(result, expected) From d0cf785e70c135cda57f926da0b9ecb8b2d94085 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 09:18:31 -0500 Subject: [PATCH 02/34] Add test --- pandas/tests/groupby/test_sample.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index bf7033fda7cc4..5cc2a86313831 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -51,6 +51,18 @@ def test_groupby_sample_frac_gt_one_without_replacement_raises(): df.groupby("a")["b"].sample(frac=1.5, replace=False) +@pytest.mark.parametrize("n", [-1, 1.5]) +def test_groupby_sample_invalid_n(n): + df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + msg = "Only non-negative integers accepted as n values" + + with pytest.raises(ValueError, match=msg): + df.groupby("a").sample(n=n) + + with pytest.raises(ValueError, match=msg): + df.groupby("a")["b"].sample(n=n) + + def test_groupby_sample_oversample(): df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) From 065633260537b37fdddb38756661f22ea41bf667 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 09:22:44 -0500 Subject: [PATCH 03/34] Add tag --- pandas/core/groupby/groupby.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 2d97f0f15d384..986cd5da38976 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2656,6 +2656,8 @@ def sample( random number generator If np.random.RandomState, use as numpy RandomState object. + .. versionadded:: 1.1.0 + Returns ------- Series or DataFrame From cb5f1053b7ca0021eea8819e74a6ca64c29547e7 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 09:55:55 -0500 Subject: [PATCH 04/34] Troubleshoot CI --- pandas/core/groupby/base.py | 1 + pandas/tests/groupby/test_whitelist.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py index 363286704ba95..08352d737dee0 100644 --- a/pandas/core/groupby/base.py +++ b/pandas/core/groupby/base.py @@ -180,6 +180,7 @@ def _gotitem(self, key, ndim, subset=None): "tail", "take", "transform", + "sample", ] ) # Valid values of `name` for `groupby.transform(name)` diff --git a/pandas/tests/groupby/test_whitelist.py b/pandas/tests/groupby/test_whitelist.py index 8e387e9202ef6..453201666e2e1 100644 --- a/pandas/tests/groupby/test_whitelist.py +++ b/pandas/tests/groupby/test_whitelist.py @@ -328,6 +328,7 @@ def test_tab_completion(mframe): "rolling", "expanding", "pipe", + "sample", } assert results == expected From 40966bf5ca4be32e587c0658e121fb42b694d743 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 10:51:08 -0500 Subject: [PATCH 05/34] doc nit --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 986cd5da38976..963c90363f130 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2646,7 +2646,7 @@ def sample( ---------- n : int, optional Number of items to return. Cannot be used with `frac`. - Default = 1 if `frac` = None. + Default = 1 if `frac` is None. frac : float, optional Fraction of items to return. Cannot be used with `n`. replace : bool, default False From 904bdcdb95927889f3a20456810d5c74c3c6f699 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 10:58:21 -0500 Subject: [PATCH 06/34] Move tag --- pandas/core/groupby/groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 963c90363f130..845a67cc8e480 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2642,6 +2642,8 @@ def sample( You can use `random_state` for reproducibility. + .. versionadded:: 1.1.0 + Parameters ---------- n : int, optional @@ -2656,8 +2658,6 @@ def sample( random number generator If np.random.RandomState, use as numpy RandomState object. - .. versionadded:: 1.1.0 - Returns ------- Series or DataFrame From 0db1ed794cadcba7bc403132889dd3f8e499944e Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 13:49:29 -0500 Subject: [PATCH 07/34] Dispatch and allow weights --- pandas/core/groupby/groupby.py | 38 ++++++++++++++--------------- pandas/tests/groupby/test_sample.py | 28 ++++++++++++++++++--- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 845a67cc8e480..ff5eb111a6c2d 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2635,6 +2635,7 @@ def sample( n: Optional[int] = None, frac: Optional[float] = None, replace: bool = False, + weights=None, random_state=None, ): """ @@ -2653,6 +2654,11 @@ def sample( Fraction of items to return. Cannot be used with `n`. replace : bool, default False Allow or disallow sampling of the same row more than once. + weights : list-like, optional + Default None results in equal probability weighting. + If passed a list-like then values must have the same length as + the underlying object and will be used as sampling probabilities + after normalization within each group. random_state : int, array-like, BitGenerator, np.random.RandomState, optional If int, array-like, or BitGenerator (NumPy>=1.17), seed for random number generator @@ -2670,30 +2676,22 @@ def sample( numpy.random.choice: Generates a random sample from a given 1-D numpy array. """ - if frac is not None and frac > 1 and not replace: - raise ValueError("replace must be set to True when frac > 1") - if n is not None and (n != int(n) or n < 0): - raise ValueError("Only non-negative integers accepted as n values") - - if n is None and frac is None: - ns = [1] * self.ngroups - elif n is None and frac is not None: - ns = [int(frac * len(i)) for i in self.indices.values()] - elif n is not None and frac is None: - ns = [n] * self.ngroups - else: - raise ValueError("Please enter a value for frac or n but not both") + from pandas.core.reshape.concat import concat - rs = com.random_state(random_state) + if weights is not None: + weights = Series(weights, index=self._selected_obj.index) + ws = [weights[idx] for idx in self.indices.values()] + else: + ws = [None] * self.ngroups - idx_list = [ - rs.choice(i, m, replace=replace) for i, m in zip(self.indices.values(), ns) + samples = [ + self._selected_obj.loc[idx].sample( + n=n, frac=frac, replace=replace, weights=w, random_state=random_state + ) + for idx, w in zip(self.indices.values(), ws) ] - cons = self._selected_obj.index._constructor - idx = cons(np.concatenate(idx_list)) - - return self._selected_obj.loc[idx] + return concat(samples, axis=self.axis) GroupBy._add_numeric_operations() diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 5cc2a86313831..59659eb1309be 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -1,6 +1,6 @@ import pytest -from pandas import DataFrame, Series +from pandas import DataFrame, Index, Series import pandas._testing as tm @@ -31,7 +31,7 @@ def test_groupby_sample_unbalanced_groups_shape(): def test_groupby_sample_n_and_frac_raises(): df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) - msg = "Please enter a value for frac or n but not both" + msg = "Please enter a value for `frac` OR `n`, not both" with pytest.raises(ValueError, match=msg): df.groupby("a").sample(n=1, frac=1.0) @@ -42,7 +42,9 @@ def test_groupby_sample_n_and_frac_raises(): def test_groupby_sample_frac_gt_one_without_replacement_raises(): df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) - msg = "replace must be set to True when frac > 1" + msg = ( + "Replace has to be set to `True` when upsampling " "the population `frac` > 1." + ) with pytest.raises(ValueError, match=msg): df.groupby("a").sample(frac=1.5, replace=False) @@ -54,7 +56,11 @@ def test_groupby_sample_frac_gt_one_without_replacement_raises(): @pytest.mark.parametrize("n", [-1, 1.5]) def test_groupby_sample_invalid_n(n): df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) - msg = "Only non-negative integers accepted as n values" + + if n < 0: + msg = "Please provide positive value" + else: + msg = "Only integers accepted as `n` values" with pytest.raises(ValueError, match=msg): df.groupby("a").sample(n=n) @@ -85,3 +91,17 @@ def test_groupby_sample_without_n_or_frac(): result = df.groupby("a")["b"].sample(n=None, frac=None) expected = Series([1, 1], name="b", index=result.index) tm.assert_series_equal(result, expected) + + +def test_groupby_sample_with_weights(): + df = DataFrame({"a": [1] * 2 + [2] * 2, "b": [1] * 4}, index=Index([0, 1, 2, 3])) + + result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0]) + expected = DataFrame( + {"a": [1] * 2 + [2] * 2, "b": [1] * 4}, index=Index([0, 0, 2, 2]) + ) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) + expected = Series([1, 1, 1, 1], name="b", index=Index([0, 0, 2, 2])) + tm.assert_series_equal(result, expected) From 07dacf2cdc5ba4114a887a4dccfa9e95de9eb112 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 14:20:12 -0500 Subject: [PATCH 08/34] black --- pandas/tests/groupby/test_sample.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 59659eb1309be..4606852559875 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -42,9 +42,7 @@ def test_groupby_sample_n_and_frac_raises(): def test_groupby_sample_frac_gt_one_without_replacement_raises(): df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) - msg = ( - "Replace has to be set to `True` when upsampling " "the population `frac` > 1." - ) + msg = "Replace has to be set to `True` when upsampling the population `frac` > 1." with pytest.raises(ValueError, match=msg): df.groupby("a").sample(frac=1.5, replace=False) From 29356455668366aaa036b7d08f6f9160759113bb Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 15:55:45 -0500 Subject: [PATCH 09/34] Add doc examples --- pandas/core/groupby/groupby.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ff5eb111a6c2d..158596a3ee463 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2675,6 +2675,30 @@ def sample( DataFrame.sample: Generate random samples from a DataFrame object. numpy.random.choice: Generates a random sample from a given 1-D numpy array. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)} + ... ) + >>> df + a b + 0 red 0 + 1 red 1 + 2 blue 2 + 3 blue 3 + 4 black 4 + 5 black 5 + >>> df.groupby("a").sample(n=1, random_state=1) + a b + 4 black 4 + 2 blue 2 + 0 red 0 + >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2) + 5 5 + 3 3 + 1 1 + Name: b, dtype: int64 """ from pandas.core.reshape.concat import concat From 3e159a8b8b0a765df0483df9e691960e80d885af Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 15:58:12 -0500 Subject: [PATCH 10/34] Fixup --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 158596a3ee463..d142903c064b3 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2673,7 +2673,7 @@ def sample( See Also -------- DataFrame.sample: Generate random samples from a DataFrame object. - numpy.random.choice: Generates a random sample from a given 1-D numpy + numpy.random.choice: Generate a random sample from a given 1-D numpy array. Examples From 2397c3acc58900851b7406ece3b1e8d808434d3f Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 16:01:38 -0500 Subject: [PATCH 11/34] Another fixup --- pandas/tests/groupby/test_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 4606852559875..a350bd7a2971e 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -52,7 +52,7 @@ def test_groupby_sample_frac_gt_one_without_replacement_raises(): @pytest.mark.parametrize("n", [-1, 1.5]) -def test_groupby_sample_invalid_n(n): +def test_groupby_sample_invalid_n_raises(n): df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) if n < 0: From 8c3dfd8936cad465f7e08d7c3d9bc8782e43131b Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 8 May 2020 16:20:12 -0500 Subject: [PATCH 12/34] Edit tests --- pandas/tests/groupby/test_sample.py | 46 ++++++++++++++++------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index a350bd7a2971e..fca938406468d 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -6,31 +6,35 @@ @pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)]) def test_groupby_sample_balanced_groups_shape(n, frac): - df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) result = df.groupby("a").sample(n=n, frac=frac) - expected = DataFrame({"a": [1] * 2 + [2] * 2, "b": [1] * 4}, index=result.index) + values = [1] * 2 + [2] * 2 + expected = DataFrame({"a": values, "b": values}, index=result.index) tm.assert_frame_equal(result, expected) result = df.groupby("a")["b"].sample(n=n, frac=frac) - expected = Series([1] * 4, name="b", index=result.index) + expected = Series(values, name="b", index=result.index) tm.assert_series_equal(result, expected) def test_groupby_sample_unbalanced_groups_shape(): - df = DataFrame({"a": [1] * 10 + [2] * 20, "b": [1] * 30}) + values = [1] * 10 + [2] * 20 + df = DataFrame({"a": values, "b": values}) result = df.groupby("a").sample(n=5) - expected = DataFrame({"a": [1] * 5 + [2] * 5, "b": [1] * 10}, index=result.index) + values = [1] * 5 + [2] * 5 + expected = DataFrame({"a": values, "b": values}, index=result.index) tm.assert_frame_equal(result, expected) result = df.groupby("a")["b"].sample(n=5) - expected = Series([1] * 10, name="b", index=result.index) + expected = Series(values, name="b", index=result.index) tm.assert_series_equal(result, expected) def test_groupby_sample_n_and_frac_raises(): - df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + df = DataFrame({"a": [1, 2], "b": [1, 2]}) msg = "Please enter a value for `frac` OR `n`, not both" with pytest.raises(ValueError, match=msg): @@ -41,7 +45,7 @@ def test_groupby_sample_n_and_frac_raises(): def test_groupby_sample_frac_gt_one_without_replacement_raises(): - df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + df = DataFrame({"a": [1, 2], "b": [1, 2]}) msg = "Replace has to be set to `True` when upsampling the population `frac` > 1." with pytest.raises(ValueError, match=msg): @@ -53,7 +57,7 @@ def test_groupby_sample_frac_gt_one_without_replacement_raises(): @pytest.mark.parametrize("n", [-1, 1.5]) def test_groupby_sample_invalid_n_raises(n): - df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + df = DataFrame({"a": [1, 2], "b": [1, 2]}) if n < 0: msg = "Please provide positive value" @@ -68,38 +72,40 @@ def test_groupby_sample_invalid_n_raises(n): def test_groupby_sample_oversample(): - df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) result = df.groupby("a").sample(frac=2.0, replace=True) - expected = DataFrame({"a": [1] * 20 + [2] * 20, "b": [1] * 40}, index=result.index) + values = [1] * 20 + [2] * 20 + expected = DataFrame({"a": values, "b": values}, index=result.index) tm.assert_frame_equal(result, expected) result = df.groupby("a")["b"].sample(frac=2.0, replace=True) - expected = Series([1] * 40, name="b", index=result.index) + expected = Series(values, name="b", index=result.index) tm.assert_series_equal(result, expected) def test_groupby_sample_without_n_or_frac(): - df = DataFrame({"a": [1] * 10 + [2] * 10, "b": [1] * 20}) + values = [1] * 10 + [2] * 10 + df = DataFrame({"a": values, "b": values}) result = df.groupby("a").sample(n=None, frac=None) - expected = DataFrame({"a": [1, 2], "b": [1, 1]}, index=result.index) + expected = DataFrame({"a": [1, 2], "b": [1, 2]}, index=result.index) tm.assert_frame_equal(result, expected) result = df.groupby("a")["b"].sample(n=None, frac=None) - expected = Series([1, 1], name="b", index=result.index) + expected = Series([1, 2], name="b", index=result.index) tm.assert_series_equal(result, expected) def test_groupby_sample_with_weights(): - df = DataFrame({"a": [1] * 2 + [2] * 2, "b": [1] * 4}, index=Index([0, 1, 2, 3])) + values = [1] * 2 + [2] * 2 + df = DataFrame({"a": values, "b": values}, index=Index([0, 1, 2, 3])) result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0]) - expected = DataFrame( - {"a": [1] * 2 + [2] * 2, "b": [1] * 4}, index=Index([0, 0, 2, 2]) - ) + expected = DataFrame({"a": values, "b": values}, index=Index([0, 0, 2, 2])) tm.assert_frame_equal(result, expected) result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) - expected = Series([1, 1, 1, 1], name="b", index=Index([0, 0, 2, 2])) + expected = Series(values, name="b", index=Index([0, 0, 2, 2])) tm.assert_series_equal(result, expected) From e6579d33e193db761fa8cc17f396ef6981be1374 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 11 May 2020 11:59:35 -0500 Subject: [PATCH 13/34] Update docstring --- pandas/core/generic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 792e5a1228fe6..7eb950e4df4da 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4858,6 +4858,10 @@ def sample( See Also -------- + DataFrameGroupBy.sample: Generates random samples from each group of a + DataFrame object. + SeriesGroupBy.sample: Generates random samples from each group of a + Series object. numpy.random.choice: Generates a random sample from a given 1-D numpy array. From 37037c2ede92f1b95c3184415c61e4c059aa2d03 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 11 May 2020 19:13:17 -0500 Subject: [PATCH 14/34] Don't use selected_obj.index --- pandas/core/groupby/groupby.py | 6 +++--- pandas/tests/groupby/test_sample.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index fbcc5d9b49d21..8736d34a84d01 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2697,16 +2697,16 @@ def sample( from pandas.core.reshape.concat import concat if weights is not None: - weights = Series(weights, index=self._selected_obj.index) + weights = Series(weights) ws = [weights[idx] for idx in self.indices.values()] else: ws = [None] * self.ngroups samples = [ - self._selected_obj.loc[idx].sample( + self.get_group(k).sample( n=n, frac=frac, replace=replace, weights=w, random_state=random_state ) - for idx, w in zip(self.indices.values(), ws) + for k, w in zip(self.groups.keys(), ws) ] return concat(samples, axis=self.axis) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index fca938406468d..480eba35f2762 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -33,6 +33,20 @@ def test_groupby_sample_unbalanced_groups_shape(): tm.assert_series_equal(result, expected) +def test_groupby_sample_index_value_spans_groups(): + values = [1] * 3 + [2] * 3 + df = DataFrame({"a": values, "b": values}, index=[1, 2, 2, 2, 2, 2]) + + result = df.groupby("a").sample(n=2) + values = [1] * 2 + [2] * 2 + expected = DataFrame({"a": values, "b": values}, index=result.index) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].sample(n=2) + expected = Series(values, name="b", index=result.index) + tm.assert_series_equal(result, expected) + + def test_groupby_sample_n_and_frac_raises(): df = DataFrame({"a": [1, 2], "b": [1, 2]}) msg = "Please enter a value for `frac` OR `n`, not both" From 1d3c4d223319d9385ad6f0ed5b7431c8fb751fc8 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Thu, 14 May 2020 13:09:40 -0500 Subject: [PATCH 15/34] Iterate over self --- pandas/core/groupby/groupby.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index e53f1eaeef978..b133be2a4610f 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2683,6 +2683,7 @@ def sample( Examples -------- + >>> df = pd.DataFrame( ... {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)} ... ) @@ -2714,10 +2715,10 @@ def sample( ws = [None] * self.ngroups samples = [ - self.get_group(k).sample( + obj.sample( n=n, frac=frac, replace=replace, weights=w, random_state=random_state ) - for k, w in zip(self.groups.keys(), ws) + for (_, obj), w in zip(self, ws) ] return concat(samples, axis=self.axis) From c2e1615031dde50232fc327d5b3f8d5d665f2c44 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Thu, 21 May 2020 16:04:28 -0500 Subject: [PATCH 16/34] Sequence --- pandas/core/groupby/groupby.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index b133be2a4610f..5ec1a07b0ca02 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -23,6 +23,7 @@ class providing the base-class of operations. List, Mapping, Optional, + Sequence, Tuple, Type, TypeVar, @@ -2640,7 +2641,7 @@ def sample( n: Optional[int] = None, frac: Optional[float] = None, replace: bool = False, - weights=None, + weights: Optional[Union[Sequence, Series]] = None, random_state=None, ): """ From 1f733d6420f67ff4e3c1bfcc6861d7153259e7a9 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 25 May 2020 17:31:15 -0500 Subject: [PATCH 17/34] Fix examples --- pandas/core/groupby/groupby.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index d9a50e3d191d5..273fc7a8e0696 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2728,7 +2728,6 @@ def sample( Examples -------- - >>> df = pd.DataFrame( ... {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)} ... ) @@ -2740,11 +2739,18 @@ def sample( 3 blue 3 4 black 4 5 black 5 + + Select one row at random for each distinct value in column a. The + `random_state` argument can be used to guarantee reproducibility: + >>> df.groupby("a").sample(n=1, random_state=1) a b 4 black 4 2 blue 2 0 red 0 + + Set `frac` to sample fixed proportions rather than counts: + >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2) 5 5 3 3 From 0369d22288d48cc70c2298cd5d205bc975847649 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 13:23:34 -0500 Subject: [PATCH 18/34] Add random_state tests --- pandas/tests/groupby/test_sample.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 480eba35f2762..175b61c515462 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -1,3 +1,6 @@ +from copy import copy + +import numpy as np import pytest from pandas import DataFrame, Index, Series @@ -123,3 +126,22 @@ def test_groupby_sample_with_weights(): result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) expected = Series(values, name="b", index=Index([0, 0, 2, 2])) tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "random_state", + [ + 0, + np.array([0, 1, 2]), + np.random.RandomState(0), + np.random.PCG64(0), + np.random.MT19937(0), + ], +) +def test_groupby_sample_using_random_state(random_state): + df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.random.random(100)}) + rs = copy(random_state) + expected = df.groupby("a").sample(frac=0.5, random_state=rs) + rs = copy(random_state) + result = df.groupby("a").sample(frac=0.5, random_state=rs) + tm.assert_frame_equal(result, expected) From 279cc3c9d5f7ea20f0a2e35da288869dec70fd91 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 13:29:16 -0500 Subject: [PATCH 19/34] Copy less --- pandas/tests/groupby/test_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 175b61c515462..a5ea0dda942f7 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -142,6 +142,6 @@ def test_groupby_sample_using_random_state(random_state): df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.random.random(100)}) rs = copy(random_state) expected = df.groupby("a").sample(frac=0.5, random_state=rs) - rs = copy(random_state) + rs = random_state result = df.groupby("a").sample(frac=0.5, random_state=rs) tm.assert_frame_equal(result, expected) From daf278b3f1ad297a929203e2705125a705f4bdc7 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 13:33:51 -0500 Subject: [PATCH 20/34] random -> arange --- pandas/tests/groupby/test_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index a5ea0dda942f7..9277c1d2278f3 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -139,7 +139,7 @@ def test_groupby_sample_with_weights(): ], ) def test_groupby_sample_using_random_state(random_state): - df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.random.random(100)}) + df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) rs = copy(random_state) expected = df.groupby("a").sample(frac=0.5, random_state=rs) rs = random_state From 88ef72cbd335a374651d1f5f11571c6ab17c5107 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 16:44:48 -0500 Subject: [PATCH 21/34] Skip for numpy version --- pandas/tests/groupby/test_sample.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 9277c1d2278f3..71217b8760000 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from pandas.compat.numpy import _np_version_under1p17 + from pandas import DataFrame, Index, Series import pandas._testing as tm @@ -129,14 +131,7 @@ def test_groupby_sample_with_weights(): @pytest.mark.parametrize( - "random_state", - [ - 0, - np.array([0, 1, 2]), - np.random.RandomState(0), - np.random.PCG64(0), - np.random.MT19937(0), - ], + "random_state", [0, np.array([0, 1, 2]), np.random.RandomState(0)], ) def test_groupby_sample_using_random_state(random_state): df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) @@ -145,3 +140,16 @@ def test_groupby_sample_using_random_state(random_state): rs = random_state result = df.groupby("a").sample(frac=0.5, random_state=rs) tm.assert_frame_equal(result, expected) + + +@pytest.mark.skipif(_np_version_under1p17, reason="Skipping for numpy version < 1.17") +@pytest.mark.parametrize( + "random_state", [np.random.PCG64(0), np.random.MT19937(0)], +) +def test_groupby_sample_using_bitgenerator(random_state): + df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) + rs = copy(random_state) + expected = df.groupby("a").sample(frac=0.5, random_state=rs) + rs = random_state + result = df.groupby("a").sample(frac=0.5, random_state=rs) + tm.assert_frame_equal(result, expected) From fb55e084a8b24154a7dfbf36d1c05078765f2c05 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 17:33:28 -0500 Subject: [PATCH 22/34] Revert "Skip for numpy version" This reverts commit 88ef72cbd335a374651d1f5f11571c6ab17c5107. --- pandas/tests/groupby/test_sample.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 71217b8760000..9277c1d2278f3 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -3,8 +3,6 @@ import numpy as np import pytest -from pandas.compat.numpy import _np_version_under1p17 - from pandas import DataFrame, Index, Series import pandas._testing as tm @@ -131,7 +129,14 @@ def test_groupby_sample_with_weights(): @pytest.mark.parametrize( - "random_state", [0, np.array([0, 1, 2]), np.random.RandomState(0)], + "random_state", + [ + 0, + np.array([0, 1, 2]), + np.random.RandomState(0), + np.random.PCG64(0), + np.random.MT19937(0), + ], ) def test_groupby_sample_using_random_state(random_state): df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) @@ -140,16 +145,3 @@ def test_groupby_sample_using_random_state(random_state): rs = random_state result = df.groupby("a").sample(frac=0.5, random_state=rs) tm.assert_frame_equal(result, expected) - - -@pytest.mark.skipif(_np_version_under1p17, reason="Skipping for numpy version < 1.17") -@pytest.mark.parametrize( - "random_state", [np.random.PCG64(0), np.random.MT19937(0)], -) -def test_groupby_sample_using_bitgenerator(random_state): - df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) - rs = copy(random_state) - expected = df.groupby("a").sample(frac=0.5, random_state=rs) - rs = random_state - result = df.groupby("a").sample(frac=0.5, random_state=rs) - tm.assert_frame_equal(result, expected) From 1a3016a1acbaa618aac6c1b719bee90005459137 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 18:23:18 -0500 Subject: [PATCH 23/34] Try again --- pandas/tests/groupby/test_sample.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 9277c1d2278f3..1c62fc3fd19b9 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from pandas.compat.numpy import _np_version_under1p17 + from pandas import DataFrame, Index, Series import pandas._testing as tm @@ -128,20 +130,17 @@ def test_groupby_sample_with_weights(): tm.assert_series_equal(result, expected) -@pytest.mark.parametrize( - "random_state", - [ - 0, - np.array([0, 1, 2]), - np.random.RandomState(0), - np.random.PCG64(0), - np.random.MT19937(0), - ], -) +@pytest.mark.parametrize("random_state", [0, np.array([0, 1, 2])]) def test_groupby_sample_using_random_state(random_state): df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) - rs = copy(random_state) - expected = df.groupby("a").sample(frac=0.5, random_state=rs) - rs = random_state - result = df.groupby("a").sample(frac=0.5, random_state=rs) + expected = df.groupby("a").sample(frac=0.5, random_state=random_state) + result = df.groupby("a").sample(frac=0.5, random_state=random_state) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.skipif(_np_version_under1p17, reason="Numpy version < 1.17") +def test_groupby_sample_using_bitgenerator(): + df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) + expected = df.groupby("a").sample(frac=0.5, random_state=np.random.MT19937(0)) + result = df.groupby("a").sample(frac=0.5, random_state=np.random.MT19937(0)) tm.assert_frame_equal(result, expected) From c136c1fbcbba324185e31028b1b6a8ef625f8030 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Mon, 1 Jun 2020 19:00:43 -0500 Subject: [PATCH 24/34] Fix --- pandas/tests/groupby/test_sample.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 1c62fc3fd19b9..2161f9d3376f2 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -1,5 +1,3 @@ -from copy import copy - import numpy as np import pytest From 372da0e3c471054337151f33088631424ffafee6 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Tue, 2 Jun 2020 18:06:13 -0500 Subject: [PATCH 25/34] Delete --- pandas/tests/groupby/test_sample.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 2161f9d3376f2..480eba35f2762 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -1,8 +1,5 @@ -import numpy as np import pytest -from pandas.compat.numpy import _np_version_under1p17 - from pandas import DataFrame, Index, Series import pandas._testing as tm @@ -126,19 +123,3 @@ def test_groupby_sample_with_weights(): result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) expected = Series(values, name="b", index=Index([0, 0, 2, 2])) tm.assert_series_equal(result, expected) - - -@pytest.mark.parametrize("random_state", [0, np.array([0, 1, 2])]) -def test_groupby_sample_using_random_state(random_state): - df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) - expected = df.groupby("a").sample(frac=0.5, random_state=random_state) - result = df.groupby("a").sample(frac=0.5, random_state=random_state) - tm.assert_frame_equal(result, expected) - - -@pytest.mark.skipif(_np_version_under1p17, reason="Numpy version < 1.17") -def test_groupby_sample_using_bitgenerator(): - df = DataFrame({"a": [1] * 50 + [2] * 50, "b": np.arange(100)}) - expected = df.groupby("a").sample(frac=0.5, random_state=np.random.MT19937(0)) - result = df.groupby("a").sample(frac=0.5, random_state=np.random.MT19937(0)) - tm.assert_frame_equal(result, expected) From b1bf65f0f9724631a497492e0d50e5deefea712b Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Tue, 2 Jun 2020 18:12:37 -0500 Subject: [PATCH 26/34] random_state --- pandas/core/groupby/groupby.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 868294a8ed1f9..734ead8666420 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2764,14 +2764,14 @@ def sample( a b 4 black 4 2 blue 2 - 0 red 0 + 0 red 1 Set `frac` to sample fixed proportions rather than counts: >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2) 5 5 - 3 3 - 1 1 + 3 2 + 1 0 Name: b, dtype: int64 """ from pandas.core.reshape.concat import concat @@ -2782,6 +2782,9 @@ def sample( else: ws = [None] * self.ngroups + if random_state: + random_state = com.random_state(random_state) + samples = [ obj.sample( n=n, frac=frac, replace=replace, weights=w, random_state=random_state From 48eea9765f25816be065f8808381de8a08c3dddf Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Tue, 2 Jun 2020 18:43:11 -0500 Subject: [PATCH 27/34] Doc --- pandas/core/groupby/groupby.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 734ead8666420..d4bb5351d3ceb 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2764,14 +2764,14 @@ def sample( a b 4 black 4 2 blue 2 - 0 red 1 + 1 red 1 Set `frac` to sample fixed proportions rather than counts: >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2) 5 5 - 3 2 - 1 0 + 2 2 + 0 0 Name: b, dtype: int64 """ from pandas.core.reshape.concat import concat From b07b377f64eaf29c33726edd3044e0e71f32dc51 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Thu, 4 Jun 2020 08:32:13 -0500 Subject: [PATCH 28/34] not None --- pandas/core/groupby/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index d4bb5351d3ceb..0afe3f905fa28 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2782,7 +2782,7 @@ def sample( else: ws = [None] * self.ngroups - if random_state: + if random_state is not None: random_state = com.random_state(random_state) samples = [ From 62f7a15207a8ab34e6ff81048884e07c8b0fd4d9 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Thu, 4 Jun 2020 08:37:31 -0500 Subject: [PATCH 29/34] doc --- doc/source/whatsnew/v1.1.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index e035a21e3a24f..e3bc4fe23bea5 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -275,7 +275,7 @@ Other enhancements such as ``dict`` and ``list``, mirroring the behavior of :meth:`DataFrame.update` (:issue:`33215`) - :meth:`~pandas.core.groupby.GroupBy.transform` and :meth:`~pandas.core.groupby.GroupBy.aggregate` has gained ``engine`` and ``engine_kwargs`` arguments that supports executing functions with ``Numba`` (:issue:`32854`, :issue:`33388`) - :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`) -- :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` now implement the ``sample`` method for doing random sampling within groups (:issue:`31775`) +- :class:`~pandas.core.groupby.generic.DataFrameGroupBy` and :class:`~pandas.core.groupby.generic.SeriesGroupBy` now implement the ``sample`` method for doing random sampling within groups (:issue:`31775`) - :meth:`DataFrame.to_numpy` now supports the ``na_value`` keyword to control the NA sentinel in the output array (:issue:`33820`) - The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals` method, similarly to :meth:`Series.equals` (:issue:`27081`). From 68d8d4a911fa0e615cc163abfb4de9c7b075a0fb Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 5 Jun 2020 11:24:06 -0500 Subject: [PATCH 30/34] doc --- pandas/core/groupby/groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 0afe3f905fa28..d64a5569b410e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2724,8 +2724,8 @@ def sample( weights : list-like, optional Default None results in equal probability weighting. If passed a list-like then values must have the same length as - the underlying object and will be used as sampling probabilities - after normalization within each group. + the underlying DataFrame or Series object and will be used as + sampling probabilities after normalization within each group. random_state : int, array-like, BitGenerator, np.random.RandomState, optional If int, array-like, or BitGenerator (NumPy>=1.17), seed for random number generator From 97034ae2b96e652594553905966d7f6c7fd7ddbd Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 5 Jun 2020 11:29:45 -0500 Subject: [PATCH 31/34] Add weights example --- pandas/core/groupby/groupby.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index d64a5569b410e..f04949810d3ab 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2773,6 +2773,18 @@ def sample( 2 2 0 0 Name: b, dtype: int64 + + Control sample probabilities within groups by setting weights: + + >>> df.groupby("a").sample( + ... n=1, + ... weights=[1, 1, 1, 0, 0, 1], + ... random_state=1, + ... ) + a b + 5 black 5 + 2 blue 2 + 0 red 0 """ from pandas.core.reshape.concat import concat From ad0bd613f82fb1c8de1df908da40782e71e147e0 Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 5 Jun 2020 14:47:13 -0500 Subject: [PATCH 32/34] Fix weights index and adjust test --- pandas/core/groupby/groupby.py | 2 +- pandas/tests/groupby/test_sample.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f04949810d3ab..022b625691df8 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2789,7 +2789,7 @@ def sample( from pandas.core.reshape.concat import concat if weights is not None: - weights = Series(weights) + weights = Series(weights, index=self._selected_obj.index) ws = [weights[idx] for idx in self.indices.values()] else: ws = [None] * self.ngroups diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py index 480eba35f2762..412e3e8f732de 100644 --- a/pandas/tests/groupby/test_sample.py +++ b/pandas/tests/groupby/test_sample.py @@ -114,12 +114,12 @@ def test_groupby_sample_without_n_or_frac(): def test_groupby_sample_with_weights(): values = [1] * 2 + [2] * 2 - df = DataFrame({"a": values, "b": values}, index=Index([0, 1, 2, 3])) + df = DataFrame({"a": values, "b": values}, index=Index(["w", "x", "y", "z"])) result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0]) - expected = DataFrame({"a": values, "b": values}, index=Index([0, 0, 2, 2])) + expected = DataFrame({"a": values, "b": values}, index=Index(["w", "w", "y", "y"])) tm.assert_frame_equal(result, expected) result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0]) - expected = Series(values, name="b", index=Index([0, 0, 2, 2])) + expected = Series(values, name="b", index=Index(["w", "w", "y", "y"])) tm.assert_series_equal(result, expected) From 05a1ba58c6d1d3241fc8970abb4993f5a6dbaaff Mon Sep 17 00:00:00 2001 From: Daniel Saxton Date: Fri, 5 Jun 2020 14:54:46 -0500 Subject: [PATCH 33/34] Update docstring --- pandas/core/groupby/groupby.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 022b625691df8..741f8712bf23a 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -2715,8 +2715,9 @@ def sample( Parameters ---------- n : int, optional - Number of items to return. Cannot be used with `frac`. - Default = 1 if `frac` is None. + Number of items to return for each group. Cannot be used with + `frac` and must be no larger than the smallest group unless + `replace` is True. Default is one if `frac` is None. frac : float, optional Fraction of items to return. Cannot be used with `n`. replace : bool, default False @@ -2726,6 +2727,8 @@ def sample( If passed a list-like then values must have the same length as the underlying DataFrame or Series object and will be used as sampling probabilities after normalization within each group. + Values must be non-negative with at least one positive element + within each group. random_state : int, array-like, BitGenerator, np.random.RandomState, optional If int, array-like, or BitGenerator (NumPy>=1.17), seed for random number generator From 56a49a033f85f566b87d3994d810c0afd38d9f70 Mon Sep 17 00:00:00 2001 From: DANIEL SAXTON Date: Tue, 9 Jun 2020 22:31:36 -0500 Subject: [PATCH 34/34] Update doc --- doc/source/reference/groupby.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/reference/groupby.rst b/doc/source/reference/groupby.rst index 5f6bef2579d27..76cb53559f334 100644 --- a/doc/source/reference/groupby.rst +++ b/doc/source/reference/groupby.rst @@ -116,6 +116,7 @@ application to columns of a specific data type. DataFrameGroupBy.quantile DataFrameGroupBy.rank DataFrameGroupBy.resample + DataFrameGroupBy.sample DataFrameGroupBy.shift DataFrameGroupBy.size DataFrameGroupBy.skew