Skip to content

PERF/ENH: allow Generator in sampling methods #42243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
-
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
-

.. ---------------------------------------------------------------------------
Expand Down
12 changes: 11 additions & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,17 @@
JSONSerializable = Optional[Union[PythonScalar, List, Dict]]
Frequency = Union[str, "DateOffset"]
Axes = Collection[Any]
RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState]

if TYPE_CHECKING:
RandomState = Union[
int,
ArrayLike,
np.random.Generator,
np.random.BitGenerator,
np.random.RandomState,
]
else:
RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is BitGenerator not included in the second (`else`) branch?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BitGenerator isn't exposed until numpy >= 1.18. Have added a comment about this and removed the else clause (since that would mean we're not type checking anyway so the definition doesn't matter)


# dtypes
NpDtype = Union[str, np.dtype]
Expand Down
29 changes: 22 additions & 7 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Iterable,
Iterator,
cast,
overload,
)
import warnings

Expand All @@ -29,6 +30,7 @@
from pandas._libs import lib
from pandas._typing import (
AnyArrayLike,
ArrayLike,
NpDtype,
RandomState,
Scalar,
Expand Down Expand Up @@ -389,16 +391,28 @@ def standardize_mapping(into):
return into


def random_state(state: RandomState | None = None) -> np.random.RandomState:
@overload
def random_state(state: np.random.Generator) -> np.random.Generator:
...


@overload
def random_state(
state: int | ArrayLike | np.random.BitGenerator | np.random.RandomState | None,
) -> np.random.RandomState:
...


def random_state(state: RandomState | None = None):
"""
Helper function for processing random_state arguments.

Parameters
----------
state : int, array-like, BitGenerator, np.random.RandomState, None.
state : int, array-like, BitGenerator, Generator, np.random.RandomState, None.
If receives an int, array-like, or BitGenerator, passes to
np.random.RandomState() as seed.
If receives an np.random.RandomState object, just returns object.
If receives an np.random RandomState or Generator, just returns that unchanged.
If receives `None`, returns np.random.
If receives anything else, raises an informative ValueError.

Expand All @@ -411,7 +425,7 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState:

Returns
-------
np.random.RandomState or np.random if state is None
np.random.RandomState or np.random.Generator. If state is None, returns np.random

"""
if (
Expand All @@ -434,12 +448,13 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState:
return np.random.RandomState(state) # type: ignore[arg-type]
elif isinstance(state, np.random.RandomState):
return state
elif isinstance(state, np.random.Generator):
return state
elif state is None:
# error: Incompatible return value type (got Module, expected "RandomState")
return np.random # type: ignore[return-value]
return np.random
else:
raise ValueError(
"random_state must be an integer, array-like, a BitGenerator, "
"random_state must be an integer, array-like, a BitGenerator, Generator, "
"a numpy RandomState, or None"
)

Expand Down
15 changes: 10 additions & 5 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5178,14 +5178,19 @@ def sample(
If weights do not sum to 1, they will be normalized to sum to 1.
Missing values in the weights column will be treated as zero.
Infinite values not allowed.
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
If int, array-like, or BitGenerator, seed for random number generator.
If np.random.RandomState, use as numpy RandomState object.
random_state : int, array-like, BitGenerator, np.random.RandomState, np.random.Generator, optional
If int, array-like, or BitGenerator, seed for random number generator.
If np.random.RandomState or np.random.Generator, use as given.

.. versionchanged:: 1.1.0

array-like and BitGenerator (for NumPy>=1.17) object now passed to
np.random.RandomState() as seed
array-like and BitGenerator object now passed to np.random.RandomState()
as seed

.. versionchanged:: 1.4.0

np.random.Generator objects now accepted

axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
Axis to sample. Accepts axis number or name. Default is stat axis
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -3211,9 +3211,14 @@ def sample(
sampling probabilities after normalization within each group.
Values must be non-negative with at least one positive element
within each group.
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
If int, array-like, or BitGenerator, seed for random number generator.
If np.random.RandomState, use as numpy RandomState object.
random_state : int, array-like, BitGenerator, np.random.RandomState, np.random.Generator, optional
If int, array-like, or BitGenerator, seed for random number generator.
If np.random.RandomState or np.random.Generator, use as given.

.. versionchanged:: 1.4.0

np.random.Generator objects now accepted

Returns
-------
Expand Down
20 changes: 18 additions & 2 deletions pandas/tests/frame/methods/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def test_sample_lengths(self, obj):
def test_sample_invalid_random_state(self, obj):
# Check for error when random_state argument invalid.
msg = (
"random_state must be an integer, array-like, a BitGenerator, a numpy "
"RandomState, or None"
"random_state must be an integer, array-like, a BitGenerator, Generator, "
"a numpy RandomState, or None"
)
with pytest.raises(ValueError, match=msg):
obj.sample(random_state="a_string")
Expand Down Expand Up @@ -177,6 +177,22 @@ def test_sample_random_state(self, func_str, arg, frame_or_series):
expected = obj.sample(n=3, random_state=com.random_state(eval(func_str)(arg)))
tm.assert_equal(result, expected)

def test_sample_generator(self, frame_or_series):
# GH#38100
# Checks that sample() accepts an np.random.Generator as random_state.
obj = frame_or_series(np.arange(100))
rng = np.random.default_rng()

# Consecutive calls sharing one Generator object should advance its state,
# so the two samples are expected to differ
result1 = obj.sample(n=50, random_state=rng)
result2 = obj.sample(n=50, random_state=rng)
assert not (result1.index.values == result2.index.values).all()

# Matching generator initialization must give same result:
# two fresh Generators built from the same seed (11) yield equal samples
result1 = obj.sample(n=50, random_state=np.random.default_rng(11))
result2 = obj.sample(n=50, random_state=np.random.default_rng(11))
tm.assert_equal(result1, result2)

def test_sample_upsampling_without_replacement(self, frame_or_series):
# GH#27451

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_random_state():

# Error for floats or strings
msg = (
"random_state must be an integer, array-like, a BitGenerator, "
"random_state must be an integer, array-like, a BitGenerator, Generator, "
"a numpy RandomState, or None"
)
with pytest.raises(ValueError, match=msg):
Expand Down