Skip to content

PERF/ENH: allow Generator in sampling methods #42243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
-
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
-

.. ---------------------------------------------------------------------------
Expand Down
11 changes: 10 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,16 @@
JSONSerializable = Optional[Union[PythonScalar, List, Dict]]
Frequency = Union[str, "DateOffset"]
Axes = Collection[Any]
RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState]

# BitGenerator isn't exposed until 1.18
if TYPE_CHECKING:
RandomState = Union[
int,
ArrayLike,
np.random.Generator,
np.random.BitGenerator,
np.random.RandomState,
]

# dtypes
NpDtype = Union[str, np.dtype]
Expand Down
33 changes: 25 additions & 8 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Iterable,
Iterator,
cast,
overload,
)
import warnings

Expand All @@ -29,8 +30,8 @@
from pandas._libs import lib
from pandas._typing import (
AnyArrayLike,
ArrayLike,
NpDtype,
RandomState,
Scalar,
T,
)
Expand All @@ -52,6 +53,9 @@
from pandas.core.dtypes.missing import isna

if TYPE_CHECKING:
# Includes BitGenerator, which only exists >= 1.18
from pandas._typing import RandomState

from pandas import Index


Expand Down Expand Up @@ -389,16 +393,28 @@ def standardize_mapping(into):
return into


def random_state(state: RandomState | None = None) -> np.random.RandomState:
@overload
def random_state(state: np.random.Generator) -> np.random.Generator:
...


@overload
def random_state(
state: int | ArrayLike | np.random.BitGenerator | np.random.RandomState | None,
) -> np.random.RandomState:
...


def random_state(state: RandomState | None = None):
"""
Helper function for processing random_state arguments.

Parameters
----------
state : int, array-like, BitGenerator, np.random.RandomState, None.
state : int, array-like, BitGenerator, Generator, np.random.RandomState, None.
If receives an int, array-like, or BitGenerator, passes to
np.random.RandomState() as seed.
If receives an np.random.RandomState object, just returns object.
If receives an np.random RandomState or Generator, just returns that unchanged.
If receives `None`, returns np.random.
If receives anything else, raises an informative ValueError.

Expand All @@ -411,7 +427,7 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState:

Returns
-------
np.random.RandomState or np.random if state is None
np.random.RandomState or np.random.Generator. If state is None, returns np.random

"""
if (
Expand All @@ -434,12 +450,13 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState:
return np.random.RandomState(state) # type: ignore[arg-type]
elif isinstance(state, np.random.RandomState):
return state
elif isinstance(state, np.random.Generator):
return state
elif state is None:
# error: Incompatible return value type (got Module, expected "RandomState")
return np.random # type: ignore[return-value]
return np.random
else:
raise ValueError(
"random_state must be an integer, array-like, a BitGenerator, "
"random_state must be an integer, array-like, a BitGenerator, Generator, "
"a numpy RandomState, or None"
)

Expand Down
17 changes: 11 additions & 6 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
JSONSerializable,
Level,
Manager,
RandomState,
Renamer,
StorageOptions,
T,
Expand Down Expand Up @@ -158,6 +157,7 @@
from typing import Literal

from pandas._libs.tslibs import BaseOffset
from pandas._typing import RandomState

from pandas.core.frame import DataFrame
from pandas.core.resample import Resampler
Expand Down Expand Up @@ -5182,14 +5182,19 @@ def sample(
If weights do not sum to 1, they will be normalized to sum to 1.
Missing values in the weights column will be treated as zero.
Infinite values not allowed.
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
If int, array-like, or BitGenerator, seed for random number generator.
If np.random.RandomState, use as numpy RandomState object.
random_state : int, array-like, BitGenerator, np.random.RandomState,
np.random.Generator, optional. If int, array-like, or BitGenerator, seed for
random number generator. If np.random.RandomState or np.random.Generator,
use as given.

.. versionchanged:: 1.1.0

array-like and BitGenerator (for NumPy>=1.17) object now passed to
np.random.RandomState() as seed
array-like and BitGenerator object now passed to np.random.RandomState()
as seed

.. versionchanged:: 1.4.0

np.random.Generator objects now accepted

axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
Axis to sample. Accepts axis number or name. Default is stat axis
Expand Down
14 changes: 10 additions & 4 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class providing the base-class of operations.
FrameOrSeries,
FrameOrSeriesUnion,
IndexLabel,
RandomState,
Scalar,
T,
final,
Expand Down Expand Up @@ -113,6 +112,8 @@ class providing the base-class of operations.
if TYPE_CHECKING:
from typing import Literal

from pandas._typing import RandomState

_common_see_also = """
See Also
--------
Expand Down Expand Up @@ -3212,9 +3213,14 @@ def sample(
sampling probabilities after normalization within each group.
Values must be non-negative with at least one positive element
within each group.
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
If int, array-like, or BitGenerator, seed for random number generator.
If np.random.RandomState, use as numpy RandomState object.
random_state : int, array-like, BitGenerator, np.random.RandomState,
np.random.Generator, optional. If int, array-like, or BitGenerator, seed for
random number generator. If np.random.RandomState or np.random.Generator,
use as given.

.. versionchanged:: 1.4.0

np.random.Generator objects now accepted

Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def sample(
size: int,
replace: bool,
weights: np.ndarray | None,
random_state: np.random.RandomState,
random_state: np.random.RandomState | np.random.Generator,
) -> np.ndarray:
"""
Randomly sample `size` indices in `np.arange(obj_len)`
Expand All @@ -125,7 +125,7 @@ def sample(
weights : np.ndarray[np.float64] or None
If None, equal probability weighting, otherwise weights according
to the vector normalized
random_state: np.random.RandomState
random_state: np.random.RandomState or np.random.Generator
State used for the random sampling

Returns
Expand Down
20 changes: 18 additions & 2 deletions pandas/tests/frame/methods/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_sample_lengths(self, obj):
def test_sample_invalid_random_state(self, obj):
# Check for error when random_state argument invalid.
msg = (
"random_state must be an integer, array-like, a BitGenerator, a numpy "
"RandomState, or None"
"random_state must be an integer, array-like, a BitGenerator, Generator, "
"a numpy RandomState, or None"
)
with pytest.raises(ValueError, match=msg):
obj.sample(random_state="a_string")
Expand Down Expand Up @@ -182,6 +182,22 @@ def test_sample_random_state(self, func_str, arg, frame_or_series):
expected = obj.sample(n=3, random_state=com.random_state(eval(func_str)(arg)))
tm.assert_equal(result, expected)

def test_sample_generator(self, frame_or_series):
# GH#38100
obj = frame_or_series(np.arange(100))
rng = np.random.default_rng()

# Consecutive calls should advance the seed
result1 = obj.sample(n=50, random_state=rng)
result2 = obj.sample(n=50, random_state=rng)
assert not (result1.index.values == result2.index.values).all()

# Matching generator initialization must give same result
# Consecutive calls should advance the seed
result1 = obj.sample(n=50, random_state=np.random.default_rng(11))
result2 = obj.sample(n=50, random_state=np.random.default_rng(11))
tm.assert_equal(result1, result2)

def test_sample_upsampling_without_replacement(self, frame_or_series):
# GH#27451

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_random_state():

# Error for floats or strings
msg = (
"random_state must be an integer, array-like, a BitGenerator, "
"random_state must be an integer, array-like, a BitGenerator, Generator, "
"a numpy RandomState, or None"
)
with pytest.raises(ValueError, match=msg):
Expand Down