diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py
index 1648985a56b91..6ca951e946bad 100644
--- a/asv_bench/benchmarks/groupby.py
+++ b/asv_bench/benchmarks/groupby.py
@@ -832,4 +832,18 @@ def function(values):
         self.grouper.agg(function, engine="cython")
 
 
+class Sample:
+    def setup(self):
+        N = 10 ** 3
+        self.df = DataFrame({"a": np.zeros(N)})
+        self.groups = np.arange(0, N)
+        self.weights = np.ones(N)
+
+    def time_sample(self):
+        self.df.groupby(self.groups).sample(n=1)
+
+    def time_sample_weights(self):
+        self.df.groupby(self.groups).sample(n=1, weights=self.weights)
+
+
 from .pandas_vb_common import setup  # noqa: F401 isort:skip
diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst
index 81545ada63ce5..e751edece7a27 100644
--- a/doc/source/whatsnew/v1.4.0.rst
+++ b/doc/source/whatsnew/v1.4.0.rst
@@ -105,7 +105,7 @@ Deprecations
 
 Performance improvements
 ~~~~~~~~~~~~~~~~~~~~~~~~
--
+- Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`)
 -
 
 .. ---------------------------------------------------------------------------
diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 82895ab9eb67a..f2497c6e65967 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -137,6 +137,7 @@
 from pandas.core.missing import find_valid_index
 from pandas.core.ops import align_method_FRAME
 from pandas.core.reshape.concat import concat
+import pandas.core.sample as sample
 from pandas.core.shared_docs import _shared_docs
 from pandas.core.sorting import get_indexer_indexer
 from pandas.core.window import (
@@ -5143,7 +5144,7 @@ def tail(self: FrameOrSeries, n: int = 5) -> FrameOrSeries:
     @final
     def sample(
         self: FrameOrSeries,
-        n=None,
+        n: int | None = None,
         frac: float | None = None,
         replace: bool_t = False,
         weights=None,
@@ -5270,92 +5271,22 @@ def sample(
             axis = self._stat_axis_number
 
         axis = self._get_axis_number(axis)
-        axis_length = self.shape[axis]
+        obj_len = self.shape[axis]
 
         # Process random_state argument
         rs = com.random_state(random_state)
 
-        # Check weights for compliance
-        if weights is not None:
-
-            # If a series, align with frame
-            if isinstance(weights, ABCSeries):
-                weights = weights.reindex(self.axes[axis])
-
-            # Strings acceptable if a dataframe and axis = 0
-            if isinstance(weights, str):
-                if isinstance(self, ABCDataFrame):
-                    if axis == 0:
-                        try:
-                            weights = self[weights]
-                        except KeyError as err:
-                            raise KeyError(
-                                "String passed to weights not a valid column"
-                            ) from err
-                    else:
-                        raise ValueError(
-                            "Strings can only be passed to "
-                            "weights when sampling from rows on "
-                            "a DataFrame"
-                        )
-                else:
-                    raise ValueError(
-                        "Strings cannot be passed as weights "
-                        "when sampling from a Series."
-                    )
-
-            if isinstance(self, ABCSeries):
-                func = self._constructor
-            else:
-                func = self._constructor_sliced
-            weights = func(weights, dtype="float64")
+        size = sample.process_sampling_size(n, frac, replace)
+        if size is None:
+            assert frac is not None
+            size = round(frac * obj_len)
 
-            if len(weights) != axis_length:
-                raise ValueError(
-                    "Weights and axis to be sampled must be of same length"
-                )
-
-            if (weights == np.inf).any() or (weights == -np.inf).any():
-                raise ValueError("weight vector may not include `inf` values")
-
-            if (weights < 0).any():
-                raise ValueError("weight vector many not include negative values")
-
-            # If has nan, set to zero.
-            weights = weights.fillna(0)
-
-            # Renormalize if don't sum to 1
-            if weights.sum() != 1:
-                if weights.sum() != 0:
-                    weights = weights / weights.sum()
-                else:
-                    raise ValueError("Invalid weights: weights sum to zero")
-
-            weights = weights._values
+        if weights is not None:
+            weights = sample.preprocess_weights(self, weights, axis)
 
-        # If no frac or n, default to n=1.
-        if n is None and frac is None:
-            n = 1
-        elif frac is not None and frac > 1 and not replace:
-            raise ValueError(
-                "Replace has to be set to `True` when "
-                "upsampling the population `frac` > 1."
-            )
-        elif frac is None and n % 1 != 0:
-            raise ValueError("Only integers accepted as `n` values")
-        elif n is None and frac is not None:
-            n = round(frac * axis_length)
-        elif frac is not None:
-            raise ValueError("Please enter a value for `frac` OR `n`, not both")
-
-        # Check for negative sizes
-        if n < 0:
-            raise ValueError(
-                "A negative number of rows requested. Please provide positive value."
-            )
+        sampled_indices = sample.sample(obj_len, size, replace, weights, rs)
+        result = self.take(sampled_indices, axis=axis)
 
-        locs = rs.choice(axis_length, size=n, replace=replace, p=weights)
-        result = self.take(locs, axis=axis)
         if ignore_index:
             result.index = ibase.default_index(len(result))
 
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 7fd5f2d52d23c..8fb50db2e33f2 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -102,6 +102,7 @@ class providing the base-class of operations.
     MultiIndex,
 )
 from pandas.core.internals.blocks import ensure_block_shape
+import pandas.core.sample as sample
 from pandas.core.series import Series
 from pandas.core.sorting import get_group_index_sorter
 from pandas.core.util.numba_ import (
@@ -3270,26 +3271,37 @@ def sample(
         2   blue  2
         0    red  0
         """
-        from pandas.core.reshape.concat import concat
-
+        size = sample.process_sampling_size(n, frac, replace)
         if weights is not None:
-            weights = Series(weights, index=self._selected_obj.index)
-            ws = [weights.iloc[idx] for idx in self.indices.values()]
-        else:
-            ws = [None] * self.ngroups
+            weights_arr = sample.preprocess_weights(
+                self._selected_obj, weights, axis=self.axis
+            )
 
-        if random_state is not None:
-            random_state = com.random_state(random_state)
+        random_state = com.random_state(random_state)
 
         group_iterator = self.grouper.get_iterator(self._selected_obj, self.axis)
-        samples = [
-            obj.sample(
-                n=n, frac=frac, replace=replace, weights=w, random_state=random_state
+
+        sampled_indices = []
+        for labels, obj in group_iterator:
+            grp_indices = self.indices[labels]
+            group_size = len(grp_indices)
+            if size is not None:
+                sample_size = size
+            else:
+                assert frac is not None
+                sample_size = round(frac * group_size)
+
+            grp_sample = sample.sample(
+                group_size,
+                size=sample_size,
+                replace=replace,
+                weights=None if weights is None else weights_arr[grp_indices],
+                random_state=random_state,
             )
-            for (_, obj), w in zip(group_iterator, ws)
-        ]
+            sampled_indices.append(grp_indices[grp_sample])
 
-        return concat(samples, axis=self.axis)
+        sampled_indices = np.concatenate(sampled_indices)
+        return self._selected_obj.take(sampled_indices, axis=self.axis)
 
 
 @doc(GroupBy)
diff --git a/pandas/core/sample.py b/pandas/core/sample.py
new file mode 100644
index 0000000000000..4798f385d523c
--- /dev/null
+++ b/pandas/core/sample.py
@@ -0,0 +1,149 @@
+"""
+Module containing utilities for NDFrame.sample() and .GroupBy.sample()
+"""
+from __future__ import annotations
+
+import numpy as np
+
+from pandas._libs import lib
+from pandas._typing import FrameOrSeries
+
+from pandas.core.dtypes.generic import (
+    ABCDataFrame,
+    ABCSeries,
+)
+
+
+def preprocess_weights(obj: FrameOrSeries, weights, axis: int) -> np.ndarray:
+    """
+    Process and validate the `weights` argument to `NDFrame.sample` and
+    `.GroupBy.sample`.
+
+    Returns `weights` as an ndarray[np.float64], validated except for normalizing
+    weights (because that must be done groupwise in groupby sampling).
+    """
+    # If a series, align with frame
+    if isinstance(weights, ABCSeries):
+        weights = weights.reindex(obj.axes[axis])
+
+    # Strings acceptable if a dataframe and axis = 0
+    if isinstance(weights, str):
+        if isinstance(obj, ABCDataFrame):
+            if axis == 0:
+                try:
+                    weights = obj[weights]
+                except KeyError as err:
+                    raise KeyError(
+                        "String passed to weights not a valid column"
+                    ) from err
+            else:
+                raise ValueError(
+                    "Strings can only be passed to "
+                    "weights when sampling from rows on "
+                    "a DataFrame"
+                )
+        else:
+            raise ValueError(
+                "Strings cannot be passed as weights when sampling from a Series."
+            )
+
+    if isinstance(obj, ABCSeries):
+        func = obj._constructor
+    else:
+        func = obj._constructor_sliced
+
+    weights = func(weights, dtype="float64")._values
+
+    if len(weights) != obj.shape[axis]:
+        raise ValueError("Weights and axis to be sampled must be of same length")
+
+    if lib.has_infs(weights):
+        raise ValueError("weight vector may not include `inf` values")
+
+    if (weights < 0).any():
+        raise ValueError("weight vector many not include negative values")
+
+    missing = np.isnan(weights)
+    if missing.any():
+        # `weights` may share memory with the caller's input; copy before
+        # zeroing NaNs so the user's data is not modified in place.
+        weights = weights.copy()
+        weights[missing] = 0
+    return weights
+
+
+def process_sampling_size(
+    n: int | None, frac: float | None, replace: bool
+) -> int | None:
+    """
+    Process and validate the `n` and `frac` arguments to `NDFrame.sample` and
+    `.GroupBy.sample`.
+
+    Returns None if `frac` should be used (variable sampling sizes), otherwise returns
+    the constant sampling size.
+    """
+    # If no frac or n, default to n=1.
+    if n is None and frac is None:
+        n = 1
+    elif n is not None and frac is not None:
+        raise ValueError("Please enter a value for `frac` OR `n`, not both")
+    elif n is not None:
+        if n < 0:
+            raise ValueError(
+                "A negative number of rows requested. Please provide `n` >= 0."
+            )
+        if n % 1 != 0:
+            raise ValueError("Only integers accepted as `n` values")
+    else:
+        assert frac is not None  # for mypy
+        if frac > 1 and not replace:
+            raise ValueError(
+                "Replace has to be set to `True` when "
+                "upsampling the population `frac` > 1."
+            )
+        if frac < 0:
+            raise ValueError(
+                "A negative number of rows requested. Please provide `frac` >= 0."
+            )
+
+    return n
+
+
+def sample(
+    obj_len: int,
+    size: int,
+    replace: bool,
+    weights: np.ndarray | None,
+    random_state: np.random.RandomState,
+) -> np.ndarray:
+    """
+    Randomly sample `size` indices in `np.arange(obj_len)`
+
+    Parameters
+    ----------
+    obj_len : int
+        The length of the indices being considered
+    size : int
+        The number of values to choose
+    replace : bool
+        Allow or disallow sampling of the same row more than once.
+    weights : np.ndarray[np.float64] or None
+        If None, equal probability weighting, otherwise weights according
+        to the vector normalized
+    random_state : np.random.RandomState
+        State used for the random sampling
+
+    Returns
+    -------
+    np.ndarray[np.intp]
+    """
+    if weights is not None:
+        weight_sum = weights.sum()
+        if weight_sum != 0:
+            weights = weights / weight_sum
+        else:
+            raise ValueError("Invalid weights: weights sum to zero")
+
+    return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
+        np.intp, copy=False
+    )
diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py
index 604788ba91633..fc90bcbf5fbdc 100644
--- a/pandas/tests/frame/methods/test_sample.py
+++ b/pandas/tests/frame/methods/test_sample.py
@@ -84,10 +84,15 @@ def test_sample_wont_accept_n_and_frac(self, obj):
             obj.sample(n=3, frac=0.3)
 
     def test_sample_requires_positive_n_frac(self, obj):
-        msg = "A negative number of rows requested. Please provide positive value."
-        with pytest.raises(ValueError, match=msg):
+        with pytest.raises(
+            ValueError,
+            match="A negative number of rows requested. Please provide `n` >= 0",
+        ):
             obj.sample(n=-3)
-        with pytest.raises(ValueError, match=msg):
+        with pytest.raises(
+            ValueError,
+            match="A negative number of rows requested. Please provide `frac` >= 0",
+        ):
             obj.sample(frac=-0.3)
 
     def test_sample_requires_integer_n(self, obj):
diff --git a/pandas/tests/groupby/test_sample.py b/pandas/tests/groupby/test_sample.py
index 652a5fc1a3c34..9153fac0927c5 100644
--- a/pandas/tests/groupby/test_sample.py
+++ b/pandas/tests/groupby/test_sample.py
@@ -78,7 +78,7 @@ def test_groupby_sample_invalid_n_raises(n):
     df = DataFrame({"a": [1, 2], "b": [1, 2]})
 
     if n < 0:
-        msg = "Please provide positive value"
+        msg = "A negative number of rows requested. Please provide `n` >= 0."
     else:
         msg = "Only integers accepted as `n` values"
 