Skip to content

PERF/REF: groupby sample #42233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 2, 2021
14 changes: 14 additions & 0 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,4 +832,18 @@ def function(values):
self.grouper.agg(function, engine="cython")


class Sample:
def setup(self):
N = 10 ** 3
self.df = DataFrame({"a": np.zeros(N)})
self.groups = np.arange(0, N)
self.weights = np.ones(N)

def time_sample(self):
self.df.groupby(self.groups).sample(n=1)

def time_sample_weights(self):
self.df.groupby(self.groups).sample(n=1, weights=self.weights)


from .pandas_vb_common import setup # noqa: F401 isort:skip
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ Deprecations

Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
-
- Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`)
-

.. ---------------------------------------------------------------------------
Expand Down
91 changes: 11 additions & 80 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
from pandas.core.missing import find_valid_index
from pandas.core.ops import align_method_FRAME
from pandas.core.reshape.concat import concat
import pandas.core.sample as sample
from pandas.core.shared_docs import _shared_docs
from pandas.core.sorting import get_indexer_indexer
from pandas.core.window import (
Expand Down Expand Up @@ -5143,7 +5144,7 @@ def tail(self: FrameOrSeries, n: int = 5) -> FrameOrSeries:
@final
def sample(
self: FrameOrSeries,
n=None,
n: int | None = None,
frac: float | None = None,
replace: bool_t = False,
weights=None,
Expand Down Expand Up @@ -5270,92 +5271,22 @@ def sample(
axis = self._stat_axis_number

axis = self._get_axis_number(axis)
axis_length = self.shape[axis]
obj_len = self.shape[axis]

# Process random_state argument
rs = com.random_state(random_state)

# Check weights for compliance
if weights is not None:

# If a series, align with frame
if isinstance(weights, ABCSeries):
weights = weights.reindex(self.axes[axis])

# Strings acceptable if a dataframe and axis = 0
if isinstance(weights, str):
if isinstance(self, ABCDataFrame):
if axis == 0:
try:
weights = self[weights]
except KeyError as err:
raise KeyError(
"String passed to weights not a valid column"
) from err
else:
raise ValueError(
"Strings can only be passed to "
"weights when sampling from rows on "
"a DataFrame"
)
else:
raise ValueError(
"Strings cannot be passed as weights "
"when sampling from a Series."
)

if isinstance(self, ABCSeries):
func = self._constructor
else:
func = self._constructor_sliced
weights = func(weights, dtype="float64")
size = sample.process_sampling_size(n, frac, replace)
if size is None:
assert frac is not None
size = round(frac * obj_len)

if len(weights) != axis_length:
raise ValueError(
"Weights and axis to be sampled must be of same length"
)

if (weights == np.inf).any() or (weights == -np.inf).any():
raise ValueError("weight vector may not include `inf` values")

if (weights < 0).any():
raise ValueError("weight vector many not include negative values")

# If has nan, set to zero.
weights = weights.fillna(0)

# Renormalize if don't sum to 1
if weights.sum() != 1:
if weights.sum() != 0:
weights = weights / weights.sum()
else:
raise ValueError("Invalid weights: weights sum to zero")

weights = weights._values
if weights is not None:
weights = sample.preprocess_weights(self, weights, axis)

# If no frac or n, default to n=1.
if n is None and frac is None:
n = 1
elif frac is not None and frac > 1 and not replace:
raise ValueError(
"Replace has to be set to `True` when "
"upsampling the population `frac` > 1."
)
elif frac is None and n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
elif n is None and frac is not None:
n = round(frac * axis_length)
elif frac is not None:
raise ValueError("Please enter a value for `frac` OR `n`, not both")

# Check for negative sizes
if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide positive value."
)
sampled_indices = sample.sample(obj_len, size, replace, weights, rs)
result = self.take(sampled_indices, axis=axis)

locs = rs.choice(axis_length, size=n, replace=replace, p=weights)
result = self.take(locs, axis=axis)
if ignore_index:
result.index = ibase.default_index(len(result))

Expand Down
40 changes: 26 additions & 14 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class providing the base-class of operations.
MultiIndex,
)
from pandas.core.internals.blocks import ensure_block_shape
import pandas.core.sample as sample
from pandas.core.series import Series
from pandas.core.sorting import get_group_index_sorter
from pandas.core.util.numba_ import (
Expand Down Expand Up @@ -3270,26 +3271,37 @@ def sample(
2 blue 2
0 red 0
"""
from pandas.core.reshape.concat import concat

size = sample.process_sampling_size(n, frac, replace)
if weights is not None:
weights = Series(weights, index=self._selected_obj.index)
ws = [weights.iloc[idx] for idx in self.indices.values()]
else:
ws = [None] * self.ngroups
weights_arr = sample.preprocess_weights(
self._selected_obj, weights, axis=self.axis
)

if random_state is not None:
random_state = com.random_state(random_state)
random_state = com.random_state(random_state)

group_iterator = self.grouper.get_iterator(self._selected_obj, self.axis)
samples = [
obj.sample(
n=n, frac=frac, replace=replace, weights=w, random_state=random_state

sampled_indices = []
for labels, obj in group_iterator:
grp_indices = self.indices[labels]
group_size = len(grp_indices)
if size is not None:
sample_size = size
else:
assert frac is not None
sample_size = round(frac * group_size)

grp_sample = sample.sample(
group_size,
size=sample_size,
replace=replace,
weights=None if weights is None else weights_arr[grp_indices],
random_state=random_state,
)
for (_, obj), w in zip(group_iterator, ws)
]
sampled_indices.append(grp_indices[grp_sample])

return concat(samples, axis=self.axis)
sampled_indices = np.concatenate(sampled_indices)
return self._selected_obj.take(sampled_indices, axis=self.axis)


@doc(GroupBy)
Expand Down
144 changes: 144 additions & 0 deletions pandas/core/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Module containing utilities for NDFrame.sample() and .GroupBy.sample()
"""
from __future__ import annotations

import numpy as np

from pandas._libs import lib
from pandas._typing import FrameOrSeries

from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
)


def preprocess_weights(obj: FrameOrSeries, weights, axis: int) -> np.ndarray:
"""
Process and validate the `weights` argument to `NDFrame.sample` and
`.GroupBy.sample`.

Returns `weights` as an ndarray[np.float64], validated except for normalizing
weights (because that must be done groupwise in groupby sampling).
"""
# If a series, align with frame
if isinstance(weights, ABCSeries):
weights = weights.reindex(obj.axes[axis])

# Strings acceptable if a dataframe and axis = 0
if isinstance(weights, str):
if isinstance(obj, ABCDataFrame):
if axis == 0:
try:
weights = obj[weights]
except KeyError as err:
raise KeyError(
"String passed to weights not a valid column"
) from err
else:
raise ValueError(
"Strings can only be passed to "
"weights when sampling from rows on "
"a DataFrame"
)
else:
raise ValueError(
"Strings cannot be passed as weights when sampling from a Series."
)

if isinstance(obj, ABCSeries):
func = obj._constructor
else:
func = obj._constructor_sliced

weights = func(weights, dtype="float64")._values

if len(weights) != obj.shape[axis]:
raise ValueError("Weights and axis to be sampled must be of same length")

if lib.has_infs(weights):
raise ValueError("weight vector may not include `inf` values")

if (weights < 0).any():
raise ValueError("weight vector many not include negative values")

weights[np.isnan(weights)] = 0
return weights


def process_sampling_size(
n: int | None, frac: float | None, replace: bool
) -> int | None:
"""
Process and validate the `n` and `frac` arguments to `NDFrame.sample` and
`.GroupBy.sample`.

Returns None if `frac` should be used (variable sampling sizes), otherwise returns
the constant sampling size.
"""
# If no frac or n, default to n=1.
if n is None and frac is None:
n = 1
elif n is not None and frac is not None:
raise ValueError("Please enter a value for `frac` OR `n`, not both")
elif n is not None:
if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide `n` >= 0."
)
if n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
else:
assert frac is not None # for mypy
if frac > 1 and not replace:
raise ValueError(
"Replace has to be set to `True` when "
"upsampling the population `frac` > 1."
)
if frac < 0:
raise ValueError(
"A negative number of rows requested. Please provide `frac` >= 0."
)

return n


def sample(
obj_len: int,
size: int,
replace: bool,
weights: np.ndarray | None,
random_state: np.random.RandomState,
) -> np.ndarray:
"""
Randomly sample `size` indices in `np.arange(obj_len)`

Parameters
----------
obj_len : int
The length of the indices being considered
size : int
The number of values to choose
replace : bool
Allow or disallow sampling of the same row more than once.
weights : np.ndarray[np.float64] or None
If None, equal probability weighting, otherwise weights according
to the vector normalized
random_state: np.random.RandomState
State used for the random sampling

Returns
-------
np.ndarray[np.intp]
"""
if weights is not None:
weight_sum = weights.sum()
if weight_sum != 0:
weights = weights / weight_sum
else:
raise ValueError("Invalid weights: weights sum to zero")

return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
np.intp, copy=False
)
11 changes: 8 additions & 3 deletions pandas/tests/frame/methods/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@ def test_sample_wont_accept_n_and_frac(self, obj):
obj.sample(n=3, frac=0.3)

def test_sample_requires_positive_n_frac(self, obj):
msg = "A negative number of rows requested. Please provide positive value."
with pytest.raises(ValueError, match=msg):
with pytest.raises(
ValueError,
match="A negative number of rows requested. Please provide `n` >= 0",
):
obj.sample(n=-3)
with pytest.raises(ValueError, match=msg):
with pytest.raises(
ValueError,
match="A negative number of rows requested. Please provide `frac` >= 0",
):
obj.sample(frac=-0.3)

def test_sample_requires_integer_n(self, obj):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/groupby/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_groupby_sample_invalid_n_raises(n):
df = DataFrame({"a": [1, 2], "b": [1, 2]})

if n < 0:
msg = "Please provide positive value"
msg = "A negative number of rows requested. Please provide `n` >= 0."
else:
msg = "Only integers accepted as `n` values"

Expand Down