Skip to content

Commit fee2b87

Browse files
authored
PERF/REF: groupby sample (#42233)
1 parent e64b7ee commit fee2b87

File tree

7 files changed

+205
-99
lines changed

7 files changed

+205
-99
lines changed

asv_bench/benchmarks/groupby.py

+14
Original file line numberDiff line numberDiff line change
@@ -832,4 +832,18 @@ def function(values):
832832
self.grouper.agg(function, engine="cython")
833833

834834

835+
class Sample:
836+
def setup(self):
837+
N = 10 ** 3
838+
self.df = DataFrame({"a": np.zeros(N)})
839+
self.groups = np.arange(0, N)
840+
self.weights = np.ones(N)
841+
842+
def time_sample(self):
843+
self.df.groupby(self.groups).sample(n=1)
844+
845+
def time_sample_weights(self):
846+
self.df.groupby(self.groups).sample(n=1, weights=self.weights)
847+
848+
835849
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.4.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ Deprecations
105105

106106
Performance improvements
107107
~~~~~~~~~~~~~~~~~~~~~~~~
108-
-
108+
- Performance improvement in :meth:`.GroupBy.sample`, especially when ``weights`` argument provided (:issue:`34483`)
109109
-
110110

111111
.. ---------------------------------------------------------------------------

pandas/core/generic.py

+11-80
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
from pandas.core.missing import find_valid_index
138138
from pandas.core.ops import align_method_FRAME
139139
from pandas.core.reshape.concat import concat
140+
import pandas.core.sample as sample
140141
from pandas.core.shared_docs import _shared_docs
141142
from pandas.core.sorting import get_indexer_indexer
142143
from pandas.core.window import (
@@ -5146,7 +5147,7 @@ def tail(self: FrameOrSeries, n: int = 5) -> FrameOrSeries:
51465147
@final
51475148
def sample(
51485149
self: FrameOrSeries,
5149-
n=None,
5150+
n: int | None = None,
51505151
frac: float | None = None,
51515152
replace: bool_t = False,
51525153
weights=None,
@@ -5273,92 +5274,22 @@ def sample(
52735274
axis = self._stat_axis_number
52745275

52755276
axis = self._get_axis_number(axis)
5276-
axis_length = self.shape[axis]
5277+
obj_len = self.shape[axis]
52775278

52785279
# Process random_state argument
52795280
rs = com.random_state(random_state)
52805281

5281-
# Check weights for compliance
5282-
if weights is not None:
5283-
5284-
# If a series, align with frame
5285-
if isinstance(weights, ABCSeries):
5286-
weights = weights.reindex(self.axes[axis])
5287-
5288-
# Strings acceptable if a dataframe and axis = 0
5289-
if isinstance(weights, str):
5290-
if isinstance(self, ABCDataFrame):
5291-
if axis == 0:
5292-
try:
5293-
weights = self[weights]
5294-
except KeyError as err:
5295-
raise KeyError(
5296-
"String passed to weights not a valid column"
5297-
) from err
5298-
else:
5299-
raise ValueError(
5300-
"Strings can only be passed to "
5301-
"weights when sampling from rows on "
5302-
"a DataFrame"
5303-
)
5304-
else:
5305-
raise ValueError(
5306-
"Strings cannot be passed as weights "
5307-
"when sampling from a Series."
5308-
)
5309-
5310-
if isinstance(self, ABCSeries):
5311-
func = self._constructor
5312-
else:
5313-
func = self._constructor_sliced
5314-
weights = func(weights, dtype="float64")
5282+
size = sample.process_sampling_size(n, frac, replace)
5283+
if size is None:
5284+
assert frac is not None
5285+
size = round(frac * obj_len)
53155286

5316-
if len(weights) != axis_length:
5317-
raise ValueError(
5318-
"Weights and axis to be sampled must be of same length"
5319-
)
5320-
5321-
if (weights == np.inf).any() or (weights == -np.inf).any():
5322-
raise ValueError("weight vector may not include `inf` values")
5323-
5324-
if (weights < 0).any():
5325-
raise ValueError("weight vector many not include negative values")
5326-
5327-
# If has nan, set to zero.
5328-
weights = weights.fillna(0)
5329-
5330-
# Renormalize if don't sum to 1
5331-
if weights.sum() != 1:
5332-
if weights.sum() != 0:
5333-
weights = weights / weights.sum()
5334-
else:
5335-
raise ValueError("Invalid weights: weights sum to zero")
5336-
5337-
weights = weights._values
5287+
if weights is not None:
5288+
weights = sample.preprocess_weights(self, weights, axis)
53385289

5339-
# If no frac or n, default to n=1.
5340-
if n is None and frac is None:
5341-
n = 1
5342-
elif frac is not None and frac > 1 and not replace:
5343-
raise ValueError(
5344-
"Replace has to be set to `True` when "
5345-
"upsampling the population `frac` > 1."
5346-
)
5347-
elif frac is None and n % 1 != 0:
5348-
raise ValueError("Only integers accepted as `n` values")
5349-
elif n is None and frac is not None:
5350-
n = round(frac * axis_length)
5351-
elif frac is not None:
5352-
raise ValueError("Please enter a value for `frac` OR `n`, not both")
5353-
5354-
# Check for negative sizes
5355-
if n < 0:
5356-
raise ValueError(
5357-
"A negative number of rows requested. Please provide positive value."
5358-
)
5290+
sampled_indices = sample.sample(obj_len, size, replace, weights, rs)
5291+
result = self.take(sampled_indices, axis=axis)
53595292

5360-
locs = rs.choice(axis_length, size=n, replace=replace, p=weights)
5361-
result = self.take(locs, axis=axis)
53625293
if ignore_index:
53635294
result.index = ibase.default_index(len(result))
53645295

pandas/core/groupby/groupby.py

+26-14
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class providing the base-class of operations.
102102
MultiIndex,
103103
)
104104
from pandas.core.internals.blocks import ensure_block_shape
105+
import pandas.core.sample as sample
105106
from pandas.core.series import Series
106107
from pandas.core.sorting import get_group_index_sorter
107108
from pandas.core.util.numba_ import (
@@ -3270,26 +3271,37 @@ def sample(
32703271
2 blue 2
32713272
0 red 0
32723273
"""
3273-
from pandas.core.reshape.concat import concat
3274-
3274+
size = sample.process_sampling_size(n, frac, replace)
32753275
if weights is not None:
3276-
weights = Series(weights, index=self._selected_obj.index)
3277-
ws = [weights.iloc[idx] for idx in self.indices.values()]
3278-
else:
3279-
ws = [None] * self.ngroups
3276+
weights_arr = sample.preprocess_weights(
3277+
self._selected_obj, weights, axis=self.axis
3278+
)
32803279

3281-
if random_state is not None:
3282-
random_state = com.random_state(random_state)
3280+
random_state = com.random_state(random_state)
32833281

32843282
group_iterator = self.grouper.get_iterator(self._selected_obj, self.axis)
3285-
samples = [
3286-
obj.sample(
3287-
n=n, frac=frac, replace=replace, weights=w, random_state=random_state
3283+
3284+
sampled_indices = []
3285+
for labels, obj in group_iterator:
3286+
grp_indices = self.indices[labels]
3287+
group_size = len(grp_indices)
3288+
if size is not None:
3289+
sample_size = size
3290+
else:
3291+
assert frac is not None
3292+
sample_size = round(frac * group_size)
3293+
3294+
grp_sample = sample.sample(
3295+
group_size,
3296+
size=sample_size,
3297+
replace=replace,
3298+
weights=None if weights is None else weights_arr[grp_indices],
3299+
random_state=random_state,
32883300
)
3289-
for (_, obj), w in zip(group_iterator, ws)
3290-
]
3301+
sampled_indices.append(grp_indices[grp_sample])
32913302

3292-
return concat(samples, axis=self.axis)
3303+
sampled_indices = np.concatenate(sampled_indices)
3304+
return self._selected_obj.take(sampled_indices, axis=self.axis)
32933305

32943306

32953307
@doc(GroupBy)

pandas/core/sample.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
Module containing utilities for NDFrame.sample() and .GroupBy.sample()
3+
"""
4+
from __future__ import annotations
5+
6+
import numpy as np
7+
8+
from pandas._libs import lib
9+
from pandas._typing import FrameOrSeries
10+
11+
from pandas.core.dtypes.generic import (
12+
ABCDataFrame,
13+
ABCSeries,
14+
)
15+
16+
17+
def preprocess_weights(obj: FrameOrSeries, weights, axis: int) -> np.ndarray:
18+
"""
19+
Process and validate the `weights` argument to `NDFrame.sample` and
20+
`.GroupBy.sample`.
21+
22+
Returns `weights` as an ndarray[np.float64], validated except for normalizing
23+
weights (because that must be done groupwise in groupby sampling).
24+
"""
25+
# If a series, align with frame
26+
if isinstance(weights, ABCSeries):
27+
weights = weights.reindex(obj.axes[axis])
28+
29+
# Strings acceptable if a dataframe and axis = 0
30+
if isinstance(weights, str):
31+
if isinstance(obj, ABCDataFrame):
32+
if axis == 0:
33+
try:
34+
weights = obj[weights]
35+
except KeyError as err:
36+
raise KeyError(
37+
"String passed to weights not a valid column"
38+
) from err
39+
else:
40+
raise ValueError(
41+
"Strings can only be passed to "
42+
"weights when sampling from rows on "
43+
"a DataFrame"
44+
)
45+
else:
46+
raise ValueError(
47+
"Strings cannot be passed as weights when sampling from a Series."
48+
)
49+
50+
if isinstance(obj, ABCSeries):
51+
func = obj._constructor
52+
else:
53+
func = obj._constructor_sliced
54+
55+
weights = func(weights, dtype="float64")._values
56+
57+
if len(weights) != obj.shape[axis]:
58+
raise ValueError("Weights and axis to be sampled must be of same length")
59+
60+
if lib.has_infs(weights):
61+
raise ValueError("weight vector may not include `inf` values")
62+
63+
if (weights < 0).any():
64+
raise ValueError("weight vector many not include negative values")
65+
66+
weights[np.isnan(weights)] = 0
67+
return weights
68+
69+
70+
def process_sampling_size(
71+
n: int | None, frac: float | None, replace: bool
72+
) -> int | None:
73+
"""
74+
Process and validate the `n` and `frac` arguments to `NDFrame.sample` and
75+
`.GroupBy.sample`.
76+
77+
Returns None if `frac` should be used (variable sampling sizes), otherwise returns
78+
the constant sampling size.
79+
"""
80+
# If no frac or n, default to n=1.
81+
if n is None and frac is None:
82+
n = 1
83+
elif n is not None and frac is not None:
84+
raise ValueError("Please enter a value for `frac` OR `n`, not both")
85+
elif n is not None:
86+
if n < 0:
87+
raise ValueError(
88+
"A negative number of rows requested. Please provide `n` >= 0."
89+
)
90+
if n % 1 != 0:
91+
raise ValueError("Only integers accepted as `n` values")
92+
else:
93+
assert frac is not None # for mypy
94+
if frac > 1 and not replace:
95+
raise ValueError(
96+
"Replace has to be set to `True` when "
97+
"upsampling the population `frac` > 1."
98+
)
99+
if frac < 0:
100+
raise ValueError(
101+
"A negative number of rows requested. Please provide `frac` >= 0."
102+
)
103+
104+
return n
105+
106+
107+
def sample(
108+
obj_len: int,
109+
size: int,
110+
replace: bool,
111+
weights: np.ndarray | None,
112+
random_state: np.random.RandomState,
113+
) -> np.ndarray:
114+
"""
115+
Randomly sample `size` indices in `np.arange(obj_len)`
116+
117+
Parameters
118+
----------
119+
obj_len : int
120+
The length of the indices being considered
121+
size : int
122+
The number of values to choose
123+
replace : bool
124+
Allow or disallow sampling of the same row more than once.
125+
weights : np.ndarray[np.float64] or None
126+
If None, equal probability weighting, otherwise weights according
127+
to the vector normalized
128+
random_state: np.random.RandomState
129+
State used for the random sampling
130+
131+
Returns
132+
-------
133+
np.ndarray[np.intp]
134+
"""
135+
if weights is not None:
136+
weight_sum = weights.sum()
137+
if weight_sum != 0:
138+
weights = weights / weight_sum
139+
else:
140+
raise ValueError("Invalid weights: weights sum to zero")
141+
142+
return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
143+
np.intp, copy=False
144+
)

pandas/tests/frame/methods/test_sample.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,15 @@ def test_sample_wont_accept_n_and_frac(self, obj):
8484
obj.sample(n=3, frac=0.3)
8585

8686
def test_sample_requires_positive_n_frac(self, obj):
87-
msg = "A negative number of rows requested. Please provide positive value."
88-
with pytest.raises(ValueError, match=msg):
87+
with pytest.raises(
88+
ValueError,
89+
match="A negative number of rows requested. Please provide `n` >= 0",
90+
):
8991
obj.sample(n=-3)
90-
with pytest.raises(ValueError, match=msg):
92+
with pytest.raises(
93+
ValueError,
94+
match="A negative number of rows requested. Please provide `frac` >= 0",
95+
):
9196
obj.sample(frac=-0.3)
9297

9398
def test_sample_requires_integer_n(self, obj):

pandas/tests/groupby/test_sample.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_groupby_sample_invalid_n_raises(n):
7878
df = DataFrame({"a": [1, 2], "b": [1, 2]})
7979

8080
if n < 0:
81-
msg = "Please provide positive value"
81+
msg = "A negative number of rows requested. Please provide `n` >= 0."
8282
else:
8383
msg = "Only integers accepted as `n` values"
8484

0 commit comments

Comments
 (0)