From 83c7d7ebae1668cadec0a40409f34584474b86c2 Mon Sep 17 00:00:00 2001 From: Matthew Zeitlin Date: Fri, 25 Jun 2021 18:48:25 -0400 Subject: [PATCH 1/5] PERF/ENH: allow Generator in sampling methods --- doc/source/whatsnew/v1.4.0.rst | 2 +- pandas/_typing.py | 4 ++- pandas/core/common.py | 31 ++++++++++++++++------- pandas/core/generic.py | 15 +++++++---- pandas/core/groupby/groupby.py | 11 +++++--- pandas/tests/frame/methods/test_sample.py | 20 +++++++++++++-- 6 files changed, 62 insertions(+), 21 deletions(-) diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index 81545ada63ce5..0f00f199a429a 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -29,7 +29,7 @@ enhancement2 Other enhancements ^^^^^^^^^^^^^^^^^^ -- +- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`) - .. 
--------------------------------------------------------------------------- diff --git a/pandas/_typing.py b/pandas/_typing.py index a9852dd4b13cf..8071d9172c869 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -122,7 +122,9 @@ JSONSerializable = Optional[Union[PythonScalar, List, Dict]] Frequency = Union[str, "DateOffset"] Axes = Collection[Any] -RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState] +RandomState = Union[ + int, ArrayLike, np.random.Generator, np.random.BitGenerator, np.random.RandomState +] # dtypes NpDtype = Union[str, np.dtype] diff --git a/pandas/core/common.py b/pandas/core/common.py index dda3e39870ffb..c5ca6696dffe9 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -21,6 +21,7 @@ Iterable, Iterator, cast, + overload, ) import warnings @@ -29,6 +30,7 @@ from pandas._libs import lib from pandas._typing import ( AnyArrayLike, + ArrayLike, NpDtype, RandomState, Scalar, @@ -389,16 +391,28 @@ def standardize_mapping(into): return into -def random_state(state: RandomState | None = None) -> np.random.RandomState: +@overload +def random_state(state: np.random.Generator) -> np.random.Generator: + ... + + +@overload +def random_state( + state: int | ArrayLike | np.random.BitGenerator | np.random.RandomState | None, +) -> np.random.RandomState: + ... + + +def random_state(state: RandomState | None): """ Helper function for processing random_state arguments. Parameters ---------- - state : int, array-like, BitGenerator, np.random.RandomState, None. + state : int, array-like, BitGenerator, Generator, np.random.RandomState, None. If receives an int, array-like, or BitGenerator, passes to np.random.RandomState() as seed. - If receives an np.random.RandomState object, just returns object. + If receives an np.random RandomState or Generator, just returns that unchanged. If receives `None`, returns np.random. If receives anything else, raises an informative ValueError. 
@@ -407,11 +421,9 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState: array-like and BitGenerator object now passed to np.random.RandomState() as seed - Default None. - Returns ------- - np.random.RandomState or np.random if state is None + np.random.RandomState or np.random.Generator """ if ( @@ -434,12 +446,13 @@ def random_state(state: RandomState | None = None) -> np.random.RandomState: return np.random.RandomState(state) # type: ignore[arg-type] elif isinstance(state, np.random.RandomState): return state + elif isinstance(state, np.random.Generator): + return state elif state is None: - # error: Incompatible return value type (got Module, expected "RandomState") - return np.random # type: ignore[return-value] + return np.random else: raise ValueError( - "random_state must be an integer, array-like, a BitGenerator, " + "random_state must be an integer, array-like, a BitGenerator, Generator, " "a numpy RandomState, or None" ) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f770dd5831e84..ba63669464903 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5178,14 +5178,19 @@ def sample( If weights do not sum to 1, they will be normalized to sum to 1. Missing values in the weights column will be treated as zero. Infinite values not allowed. - random_state : int, array-like, BitGenerator, np.random.RandomState, optional - If int, array-like, or BitGenerator, seed for random number generator. - If np.random.RandomState, use as numpy RandomState object. + random_state : int, array-like, BitGenerator, np.random.RandomState, + np.random.Generator, optional. If int, array-like, or BitGenerator, seed for + random number generator. If np.random.RandomState or np.random.Generator, + use as given. .. 
versionchanged:: 1.1.0 - array-like and BitGenerator (for NumPy>=1.17) object now passed to - np.random.RandomState() as seed + array-like and BitGenerator object now passed to np.random.RandomState() + as seed + + .. versionchanged:: 1.4.0 + + np.random.Generator objects now accepted axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 7fd5f2d52d23c..529a5736e3379 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3211,9 +3211,14 @@ def sample( sampling probabilities after normalization within each group. Values must be non-negative with at least one positive element within each group. - random_state : int, array-like, BitGenerator, np.random.RandomState, optional - If int, array-like, or BitGenerator, seed for random number generator. - If np.random.RandomState, use as numpy RandomState object. + random_state : int, array-like, BitGenerator, np.random.RandomState, + np.random.Generator, optional. If int, array-like, or BitGenerator, seed for + random number generator. If np.random.RandomState or np.random.Generator, + use as given. + + .. versionchanged:: 1.4.0 + + np.random.Generator objects now accepted Returns ------- diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py index 604788ba91633..43f0c6cc5b3b1 100644 --- a/pandas/tests/frame/methods/test_sample.py +++ b/pandas/tests/frame/methods/test_sample.py @@ -71,8 +71,8 @@ def test_sample_lengths(self, obj): def test_sample_invalid_random_state(self, obj): # Check for error when random_state argument invalid. 
msg = ( - "random_state must be an integer, array-like, a BitGenerator, a numpy " - "RandomState, or None" + "random_state must be an integer, array-like, a BitGenerator, Generator, " + "a numpy RandomState, or None" ) with pytest.raises(ValueError, match=msg): obj.sample(random_state="a_string") @@ -177,6 +177,22 @@ def test_sample_random_state(self, func_str, arg, frame_or_series): expected = obj.sample(n=3, random_state=com.random_state(eval(func_str)(arg))) tm.assert_equal(result, expected) + def test_sample_generator(self, frame_or_series): + # GH#38100 + obj = frame_or_series(np.arange(100)) + rng = np.random.default_rng() + + # Consecutive calls should advance the seed + result1 = obj.sample(n=50, random_state=rng) + result2 = obj.sample(n=50, random_state=rng) + assert not (result1.index.values == result2.index.values).all() + + # Matching generator initialization must give same result + # (two generators created with the same seed draw identical samples) + result1 = obj.sample(n=50, random_state=np.random.default_rng(11)) + result2 = obj.sample(n=50, random_state=np.random.default_rng(11)) + tm.assert_equal(result1, result2) + def test_sample_upsampling_without_replacement(self, frame_or_series): # GH#27451 From 88085d92ff708e0a6c3aaed3de7275046b44738a Mon Sep 17 00:00:00 2001 From: Matthew Zeitlin Date: Fri, 25 Jun 2021 19:00:06 -0400 Subject: [PATCH 2/5] Clean up diff --- pandas/core/common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pandas/core/common.py b/pandas/core/common.py index c5ca6696dffe9..ab357bf4f7e75 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -403,7 +403,7 @@ def random_state( ... -def random_state(state: RandomState | None): +def random_state(state: RandomState | None = None): """ Helper function for processing random_state arguments. @@ -421,9 +421,11 @@ def random_state(state: RandomState | None): array-like and BitGenerator object now passed to np.random.RandomState() as seed + Default None. 
+ Returns ------- - np.random.RandomState or np.random.Generator + np.random.RandomState or np.random.Generator. If state is None, returns np.random """ if ( From 08c4588df402b9a8c5329ea5bd169db5055983eb Mon Sep 17 00:00:00 2001 From: Matthew Zeitlin Date: Fri, 25 Jun 2021 20:24:16 -0400 Subject: [PATCH 3/5] Wrap in type checking --- pandas/_typing.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index 8071d9172c869..f2d622f34fd05 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -122,9 +122,17 @@ JSONSerializable = Optional[Union[PythonScalar, List, Dict]] Frequency = Union[str, "DateOffset"] Axes = Collection[Any] -RandomState = Union[ - int, ArrayLike, np.random.Generator, np.random.BitGenerator, np.random.RandomState -] + +if TYPE_CHECKING: + RandomState = Union[ + int, + ArrayLike, + np.random.Generator, + np.random.BitGenerator, + np.random.RandomState, + ] +else: + RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState] # dtypes NpDtype = Union[str, np.dtype] From 0fb9a6c0ae6857d38bd929c4f7adc04a9010ad08 Mon Sep 17 00:00:00 2001 From: Matthew Zeitlin Date: Fri, 25 Jun 2021 21:22:59 -0400 Subject: [PATCH 4/5] Fix common test --- pandas/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 93c95b3004876..f123f0ac9f5f8 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -84,7 +84,7 @@ def test_random_state(): # Error for floats or strings msg = ( - "random_state must be an integer, array-like, a BitGenerator, " + "random_state must be an integer, array-like, a BitGenerator, Generator, " "a numpy RandomState, or None" ) with pytest.raises(ValueError, match=msg): From 82ef3fb483d7ef0c9adb4224bbd6a4f9e96673f6 Mon Sep 17 00:00:00 2001 From: Matthew Zeitlin Date: Thu, 1 Jul 2021 20:31:28 -0400 Subject: [PATCH 5/5] Clean up TYPE_CHECKING 
handling --- pandas/_typing.py | 3 +-- pandas/core/common.py | 4 +++- pandas/core/generic.py | 2 +- pandas/core/groupby/groupby.py | 3 ++- pandas/core/sample.py | 4 ++-- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index ffc706304f4c9..d14699844094c 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -126,6 +126,7 @@ Frequency = Union[str, "DateOffset"] Axes = Collection[Any] +# np.random.BitGenerator isn't exposed until NumPy 1.18 if TYPE_CHECKING: RandomState = Union[ int, @@ -134,8 +135,6 @@ np.random.BitGenerator, np.random.RandomState, ] -else: - RandomState = Union[int, ArrayLike, np.random.Generator, np.random.RandomState] # dtypes NpDtype = Union[str, np.dtype] diff --git a/pandas/core/common.py b/pandas/core/common.py index ab357bf4f7e75..84d18f81292c9 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -32,7 +32,6 @@ AnyArrayLike, ArrayLike, NpDtype, - RandomState, Scalar, T, ) @@ -54,6 +53,9 @@ from pandas.core.dtypes.missing import isna if TYPE_CHECKING: + # Includes BitGenerator, which only exists in NumPy >= 1.18 + from pandas._typing import RandomState + from pandas import Index diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 4b1e4675d4402..47a8d5248cd32 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -46,7 +46,6 @@ JSONSerializable, Level, Manager, - RandomState, Renamer, StorageOptions, T, @@ -158,6 +157,7 @@ from typing import Literal from pandas._libs.tslibs import BaseOffset + from pandas._typing import RandomState from pandas.core.frame import DataFrame from pandas.core.resample import Resampler diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 83160ea08229f..3ab6c62323f5e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -47,7 +47,6 @@ class providing the base-class of operations. 
FrameOrSeries, FrameOrSeriesUnion, IndexLabel, - RandomState, Scalar, T, final, @@ -113,6 +112,8 @@ class providing the base-class of operations. if TYPE_CHECKING: from typing import Literal + from pandas._typing import RandomState + _common_see_also = """ See Also -------- diff --git a/pandas/core/sample.py b/pandas/core/sample.py index 4798f385d523c..e4bad22e8e43c 100644 --- a/pandas/core/sample.py +++ b/pandas/core/sample.py @@ -109,7 +109,7 @@ def sample( size: int, replace: bool, weights: np.ndarray | None, - random_state: np.random.RandomState, + random_state: np.random.RandomState | np.random.Generator, ) -> np.ndarray: """ Randomly sample `size` indices in `np.arange(obj_len)` @@ -125,7 +125,7 @@ def sample( weights : np.ndarray[np.float64] or None If None, equal probability weighting, otherwise weights according to the vector normalized - random_state: np.random.RandomState + random_state: np.random.RandomState or np.random.Generator State used for the random sampling Returns