diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 45c280b89ea28..02ea989ed5525 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -29,7 +29,7 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ -- +- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`) - .. --------------------------------------------------------------------------- diff --git a/pandas/_typing.py b/pandas/_typing.py index ccf7699fb0b4b..d14699844094c 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -125,7 +125,16 @@ JSONSerializable = Optional[Union[PythonScalar, List, Dict]] Frequency = Union[str, "DateOffset"] Axes = Collection[Any] -RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState] + +# BitGenerator isn't exposed until 1.18 +if TYPE_CHECKING: + RandomState = Union[ + int, + ArrayLike, + np.random.Generator, + np.random.BitGenerator, + np.random.RandomState, + ] # dtypes NpDtype = Union[str, np.dtype] diff --git a/pandas/core/common.py b/pandas/core/common.py index dda3e39870ffb..84d18f81292c9 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -21,6 +21,7 @@ Iterable, Iterator, cast, + overload, ) import warnings @@ -29,8 +30,8 @@ from pandas._libs import lib from pandas._typing import ( AnyArrayLike, + ArrayLike, NpDtype, - RandomState, Scalar, T, ) @@ -52,6 +53,9 @@ from pandas.core.dtypes.missing import isna if TYPE_CHECKING: + # Includes BitGenerator, which only exists >= 1.18 + from pandas._typing import RandomState + from pandas import Index @@ -389,16 +393,28 @@ def standardize_mapping(into): return into -def random_state(state: RandomState | None = None) -> np.random.RandomState: +@overload +def random_state(state: np.random.Generator) -> np.random.Generator: + ... 
+ + +@overload +def random_state( + state: int | ArrayLike | np.random.BitGenerator | np.random.RandomState | None, +) -> np.random.RandomState: + ... + + +def random_state(state: RandomState | None = None): """ Helper function for processing random_state arguments. Parameters ---------- - state : int, array-like, BitGenerator, np.random.RandomState, None. + state : int, array-like, BitGenerator, Generator, np.random.RandomState, None. If receives an int, array-like, or BitGenerator, passes to np.random.RandomState() as seed. - If receives an np.random.RandomState object, just returns object. + If receives an np.random RandomState or Generator, just returns that unchanged. If receives `None`, returns np.random. If receives anything else, raises an informative ValueError. @@ -411,7 +427,7 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState: Returns ------- - np.random.RandomState or np.random if state is None + np.random.RandomState or np.random.Generator. If state is None, returns np.random """ if ( @@ -434,12 +450,13 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState: return np.random.RandomState(state) # type: ignore[arg-type] elif isinstance(state, np.random.RandomState): return state + elif isinstance(state, np.random.Generator): + return state elif state is None: - # error: Incompatible return value type (got Module, expected "RandomState") - return np.random # type: ignore[return-value] + return np.random else: raise ValueError( - "random_state must be an integer, array-like, a BitGenerator, " + "random_state must be an integer, array-like, a BitGenerator, Generator, " "a numpy RandomState, or None" ) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 1abb01099f977..47a8d5248cd32 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -46,7 +46,6 @@ JSONSerializable, Level, Manager, - RandomState, Renamer, StorageOptions, T, @@ -158,6 +157,7 @@ from typing import Literal from 
pandas._libs.tslibs import BaseOffset + from pandas._typing import RandomState from pandas.core.frame import DataFrame from pandas.core.resample import Resampler @@ -5182,14 +5182,19 @@ def sample( If weights do not sum to 1, they will be normalized to sum to 1. Missing values in the weights column will be treated as zero. Infinite values not allowed. - random_state : int, array-like, BitGenerator, np.random.RandomState, optional - If int, array-like, or BitGenerator, seed for random number generator. - If np.random.RandomState, use as numpy RandomState object. + random_state : int, array-like, BitGenerator, np.random.RandomState, + np.random.Generator, optional. If int, array-like, or BitGenerator, seed for + random number generator. If np.random.RandomState or np.random.Generator, + use as given. .. versionchanged:: 1.1.0 - array-like and BitGenerator (for NumPy>=1.17) object now passed to - np.random.RandomState() as seed + array-like and BitGenerator object now passed to np.random.RandomState() + as seed + + .. versionchanged:: 1.4.0 + + np.random.Generator objects now accepted axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 8fb50db2e33f2..3ab6c62323f5e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -47,7 +47,6 @@ class providing the base-class of operations. FrameOrSeries, FrameOrSeriesUnion, IndexLabel, - RandomState, Scalar, T, final, @@ -113,6 +112,8 @@ class providing the base-class of operations. if TYPE_CHECKING: from typing import Literal + from pandas._typing import RandomState + _common_see_also = """ See Also -------- @@ -3212,9 +3213,14 @@ def sample( sampling probabilities after normalization within each group. Values must be non-negative with at least one positive element within each group. 
- random_state : int, array-like, BitGenerator, np.random.RandomState, optional - If int, array-like, or BitGenerator, seed for random number generator. - If np.random.RandomState, use as numpy RandomState object. + random_state : int, array-like, BitGenerator, np.random.RandomState, + np.random.Generator, optional. If int, array-like, or BitGenerator, seed for + random number generator. If np.random.RandomState or np.random.Generator, + use as given. + + .. versionchanged:: 1.4.0 + + np.random.Generator objects now accepted Returns ------- diff --git a/pandas/core/sample.py b/pandas/core/sample.py index 4798f385d523c..e4bad22e8e43c 100644 --- a/pandas/core/sample.py +++ b/pandas/core/sample.py @@ -109,7 +109,7 @@ def sample( size: int, replace: bool, weights: np.ndarray | None, - random_state: np.random.RandomState, + random_state: np.random.RandomState | np.random.Generator, ) -> np.ndarray: """ Randomly sample `size` indices in `np.arange(obj_len)` @@ -125,7 +125,7 @@ def sample( weights : np.ndarray[np.float64] or None If None, equal probability weighting, otherwise weights according to the vector normalized - random_state: np.random.RandomState + random_state: np.random.RandomState or np.random.Generator State used for the random sampling Returns diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py index fc90bcbf5fbdc..90cdda7d7b3e9 100644 --- a/pandas/tests/frame/methods/test_sample.py +++ b/pandas/tests/frame/methods/test_sample.py @@ -71,8 +71,8 @@ def test_sample_lengths(self, obj): def test_sample_invalid_random_state(self, obj): # Check for error when random_state argument invalid. 
msg = ( - "random_state must be an integer, array-like, a BitGenerator, a numpy " - "RandomState, or None" + "random_state must be an integer, array-like, a BitGenerator, Generator, " + "a numpy RandomState, or None" ) with pytest.raises(ValueError, match=msg): obj.sample(random_state="a_string") @@ -182,6 +182,22 @@ def test_sample_random_state(self, func_str, arg, frame_or_series): expected = obj.sample(n=3, random_state=com.random_state(eval(func_str)(arg))) tm.assert_equal(result, expected) + def test_sample_generator(self, frame_or_series): + # GH#38100 + obj = frame_or_series(np.arange(100)) + rng = np.random.default_rng() + + # Consecutive calls should advance the seed + result1 = obj.sample(n=50, random_state=rng) + result2 = obj.sample(n=50, random_state=rng) + assert not (result1.index.values == result2.index.values).all() + + # Matching generator initialization must give same result + # (each call gets a fresh, identically seeded generator) + result1 = obj.sample(n=50, random_state=np.random.default_rng(11)) + result2 = obj.sample(n=50, random_state=np.random.default_rng(11)) + tm.assert_equal(result1, result2) + def test_sample_upsampling_without_replacement(self, frame_or_series): # GH#27451 diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 93c95b3004876..f123f0ac9f5f8 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -84,7 +84,7 @@ def test_random_state(): # Error for floats or strings msg = ( - "random_state must be an integer, array-like, a BitGenerator, " + "random_state must be an integer, array-like, a BitGenerator, Generator, " "a numpy RandomState, or None" ) with pytest.raises(ValueError, match=msg):