Skip to content

Commit 9a6ce3c

Browse files
authored
PERF/ENH: allow Generator in sampling methods (pandas-dev#42243)
1 parent 214bdfb commit 9a6ce3c

File tree

8 files changed

+78
-25
lines changed

8 files changed

+78
-25
lines changed

doc/source/whatsnew/v1.4.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ enhancement2
2929

3030
Other enhancements
3131
^^^^^^^^^^^^^^^^^^
32-
-
32+
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
3333
-
3434

3535
.. ---------------------------------------------------------------------------

pandas/_typing.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,16 @@
119119
JSONSerializable = Optional[Union[PythonScalar, List, Dict]]
120120
Frequency = Union[str, "DateOffset"]
121121
Axes = Collection[Any]
122-
RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState]
122+
123+
# BitGenerator isn't exposed until 1.18
124+
if TYPE_CHECKING:
125+
RandomState = Union[
126+
int,
127+
ArrayLike,
128+
np.random.Generator,
129+
np.random.BitGenerator,
130+
np.random.RandomState,
131+
]
123132

124133
# dtypes
125134
NpDtype = Union[str, np.dtype]

pandas/core/common.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Iterable,
2222
Iterator,
2323
cast,
24+
overload,
2425
)
2526
import warnings
2627

@@ -29,8 +30,8 @@
2930
from pandas._libs import lib
3031
from pandas._typing import (
3132
AnyArrayLike,
33+
ArrayLike,
3234
NpDtype,
33-
RandomState,
3435
Scalar,
3536
T,
3637
)
@@ -52,6 +53,9 @@
5253
from pandas.core.dtypes.missing import isna
5354

5455
if TYPE_CHECKING:
56+
# Includes BitGenerator, which only exists >= 1.18
57+
from pandas._typing import RandomState
58+
5559
from pandas import Index
5660

5761

@@ -389,16 +393,28 @@ def standardize_mapping(into):
389393
return into
390394

391395

392-
def random_state(state: RandomState | None = None) -> np.random.RandomState:
396+
@overload
397+
def random_state(state: np.random.Generator) -> np.random.Generator:
398+
...
399+
400+
401+
@overload
402+
def random_state(
403+
state: int | ArrayLike | np.random.BitGenerator | np.random.RandomState | None,
404+
) -> np.random.RandomState:
405+
...
406+
407+
408+
def random_state(state: RandomState | None = None):
393409
"""
394410
Helper function for processing random_state arguments.
395411
396412
Parameters
397413
----------
398-
state : int, array-like, BitGenerator, np.random.RandomState, None.
414+
state : int, array-like, BitGenerator, Generator, np.random.RandomState, None.
399415
If receives an int, array-like, or BitGenerator, passes to
400416
np.random.RandomState() as seed.
401-
If receives an np.random.RandomState object, just returns object.
417+
If receives an np.random RandomState or Generator, just returns that unchanged.
402418
If receives `None`, returns np.random.
403419
If receives anything else, raises an informative ValueError.
404420
@@ -411,7 +427,7 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState:
411427
412428
Returns
413429
-------
414-
np.random.RandomState or np.random if state is None
430+
np.random.RandomState or np.random.Generator. If state is None, returns np.random
415431
416432
"""
417433
if (
@@ -434,12 +450,13 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState:
434450
return np.random.RandomState(state) # type: ignore[arg-type]
435451
elif isinstance(state, np.random.RandomState):
436452
return state
453+
elif isinstance(state, np.random.Generator):
454+
return state
437455
elif state is None:
438-
# error: Incompatible return value type (got Module, expected "RandomState")
439-
return np.random # type: ignore[return-value]
456+
return np.random
440457
else:
441458
raise ValueError(
442-
"random_state must be an integer, array-like, a BitGenerator, "
459+
"random_state must be an integer, array-like, a BitGenerator, Generator, "
443460
"a numpy RandomState, or None"
444461
)
445462

pandas/core/generic.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
JSONSerializable,
4747
Level,
4848
Manager,
49-
RandomState,
5049
Renamer,
5150
StorageOptions,
5251
T,
@@ -158,6 +157,7 @@
158157
from typing import Literal
159158

160159
from pandas._libs.tslibs import BaseOffset
160+
from pandas._typing import RandomState
161161

162162
from pandas.core.frame import DataFrame
163163
from pandas.core.resample import Resampler
@@ -5184,14 +5184,19 @@ def sample(
51845184
If weights do not sum to 1, they will be normalized to sum to 1.
51855185
Missing values in the weights column will be treated as zero.
51865186
Infinite values not allowed.
5187-
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
5188-
If int, array-like, or BitGenerator, seed for random number generator.
5189-
If np.random.RandomState, use as numpy RandomState object.
5187+
random_state : int, array-like, BitGenerator, np.random.RandomState,
5188+
np.random.Generator, optional. If int, array-like, or BitGenerator, seed for
5189+
random number generator. If np.random.RandomState or np.random.Generator,
5190+
use as given.
51905191
51915192
.. versionchanged:: 1.1.0
51925193
5193-
array-like and BitGenerator (for NumPy>=1.17) object now passed to
5194-
np.random.RandomState() as seed
5194+
array-like and BitGenerator object now passed to np.random.RandomState()
5195+
as seed
5196+
5197+
.. versionchanged:: 1.4.0
5198+
5199+
np.random.Generator objects now accepted
51955200
51965201
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
51975202
Axis to sample. Accepts axis number or name. Default is stat axis

pandas/core/groupby/groupby.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ class providing the base-class of operations.
4646
F,
4747
FrameOrSeries,
4848
IndexLabel,
49-
RandomState,
5049
Scalar,
5150
T,
5251
final,
@@ -112,6 +111,8 @@ class providing the base-class of operations.
112111
if TYPE_CHECKING:
113112
from typing import Literal
114113

114+
from pandas._typing import RandomState
115+
115116
_common_see_also = """
116117
See Also
117118
--------
@@ -3211,9 +3212,14 @@ def sample(
32113212
sampling probabilities after normalization within each group.
32123213
Values must be non-negative with at least one positive element
32133214
within each group.
3214-
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
3215-
If int, array-like, or BitGenerator, seed for random number generator.
3216-
If np.random.RandomState, use as numpy RandomState object.
3215+
random_state : int, array-like, BitGenerator, np.random.RandomState,
3216+
np.random.Generator, optional. If int, array-like, or BitGenerator, seed for
3217+
random number generator. If np.random.RandomState or np.random.Generator,
3218+
use as given.
3219+
3220+
.. versionchanged:: 1.4.0
3221+
3222+
np.random.Generator objects now accepted
32173223
32183224
Returns
32193225
-------

pandas/core/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def sample(
109109
size: int,
110110
replace: bool,
111111
weights: np.ndarray | None,
112-
random_state: np.random.RandomState,
112+
random_state: np.random.RandomState | np.random.Generator,
113113
) -> np.ndarray:
114114
"""
115115
Randomly sample `size` indices in `np.arange(obj_len)`
@@ -125,7 +125,7 @@ def sample(
125125
weights : np.ndarray[np.float64] or None
126126
If None, equal probability weighting, otherwise weights according
127127
to the vector normalized
128-
random_state: np.random.RandomState
128+
random_state: np.random.RandomState or np.random.Generator
129129
State used for the random sampling
130130
131131
Returns

pandas/tests/frame/methods/test_sample.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def test_sample_lengths(self, obj):
7171
def test_sample_invalid_random_state(self, obj):
7272
# Check for error when random_state argument invalid.
7373
msg = (
74-
"random_state must be an integer, array-like, a BitGenerator, a numpy "
75-
"RandomState, or None"
74+
"random_state must be an integer, array-like, a BitGenerator, Generator, "
75+
"a numpy RandomState, or None"
7676
)
7777
with pytest.raises(ValueError, match=msg):
7878
obj.sample(random_state="a_string")
@@ -182,6 +182,22 @@ def test_sample_random_state(self, func_str, arg, frame_or_series):
182182
expected = obj.sample(n=3, random_state=com.random_state(eval(func_str)(arg)))
183183
tm.assert_equal(result, expected)
184184

185+
def test_sample_generator(self, frame_or_series):
186+
# GH#38100
187+
obj = frame_or_series(np.arange(100))
188+
rng = np.random.default_rng()
189+
190+
# Consecutive calls should advance the seed
191+
result1 = obj.sample(n=50, random_state=rng)
192+
result2 = obj.sample(n=50, random_state=rng)
193+
assert not (result1.index.values == result2.index.values).all()
194+
195+
# Matching generator initialization must give same result
196+
# Consecutive calls should advance the seed
197+
result1 = obj.sample(n=50, random_state=np.random.default_rng(11))
198+
result2 = obj.sample(n=50, random_state=np.random.default_rng(11))
199+
tm.assert_equal(result1, result2)
200+
185201
def test_sample_upsampling_without_replacement(self, frame_or_series):
186202
# GH#27451
187203

pandas/tests/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_random_state():
8484

8585
# Error for floats or strings
8686
msg = (
87-
"random_state must be an integer, array-like, a BitGenerator, "
87+
"random_state must be an integer, array-like, a BitGenerator, Generator, "
8888
"a numpy RandomState, or None"
8989
)
9090
with pytest.raises(ValueError, match=msg):

0 commit comments

Comments
 (0)