From 1e13515e7fb9490cee95d323f635de898712d9e5 Mon Sep 17 00:00:00 2001 From: James Myatt Date: Tue, 6 Apr 2021 14:48:16 +0100 Subject: [PATCH 1/3] BUG: Support numpy.random.Generator as random_state input Workaround for #38100 --- pandas/core/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/core/common.py b/pandas/core/common.py index 98606f5d3d240..a77dab1719a21 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -425,6 +425,8 @@ def random_state(state=None): or (not np_version_under1p18 and isinstance(state, np.random.BitGenerator)) ): return np.random.RandomState(state) + elif not np_version_under1p18 and isinstance(state, np.random.Generator): + return np.random.RandomState(state.bit_generator) elif isinstance(state, np.random.RandomState): return state elif state is None: From 96d91033655d3c5a8c6a6f9c65dba6d754e72865 Mon Sep 17 00:00:00 2001 From: James Myatt Date: Tue, 6 Apr 2021 14:58:05 +0100 Subject: [PATCH 2/3] TST: Check Generator input to com.random_state --- pandas/tests/test_common.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 696395e50dd02..5f8d0acc23bf7 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -82,6 +82,18 @@ def test_random_state(): == npr.RandomState(npr.PCG64(11)).uniform() ) + # Check Generators + # GH38100 + if not np_version_under1p17: + rng1 = npr.default_rng(4) + rng2 = npr.default_rng(4) + assert ( + com.random_state(rng1).uniform() + == npr.RandomState(rng2.bit_generator).uniform() + ) + assert rng1.uniform() == rng2.uniform() + assert com.random_state(rng1).uniform() != rng2.uniform() + # Error for floats or strings msg = ( "random_state must be an integer, array-like, a BitGenerator, " From 9f3f550bdb26bb451c4791e9bf90ee1405f2518a Mon Sep 17 00:00:00 2001 From: James Myatt Date: Tue, 6 Apr 2021 21:21:43 +0100 Subject: [PATCH 3/3] TST: Update test_common.test_random_state --- pandas/tests/test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 5f8d0acc23bf7..28b415054bcad 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -91,8 +91,8 @@ def test_random_state(): com.random_state(rng1).uniform() == npr.RandomState(rng2.bit_generator).uniform() ) + # Check base RNGs have advanced correctly assert rng1.uniform() == rng2.uniform() - assert com.random_state(rng1).uniform() != rng2.uniform() # Error for floats or strings msg = (