Skip to content

[BUG] Loosen random_state input restriction #32510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 17, 2020
6 changes: 5 additions & 1 deletion pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,11 @@ def random_state(state=None):
-------
np.random.RandomState
"""
if is_integer(state):
if (
is_integer(state)
or is_array_like(state)
or isinstance(state, np.random.BitGenerator)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BitGenerator is somewhat new (NumPy>=1.17 IIRC). You'll need to put this in a conditional check. pandas.compat has some helpers.

):
return np.random.RandomState(state)
elif isinstance(state, np.random.RandomState):
return state
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def test_random_state():
# check with no arg random state
assert com.random_state() is np.random

# check array-like
state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32")
assert (
com.random_state(state_arr_like).uniform()
== npr.RandomState(state_arr_like).uniform()
)

# Check BitGenerators
assert (
com.random_state(npr.MT19937(3)).uniform()
== npr.RandomState(npr.MT19937(3)).uniform()
)
assert (
com.random_state(npr.PCG64(11)).uniform()
== npr.RandomState(npr.PCG64(11)).uniform()
)

# Error for floats or strings
msg = "random_state must be an integer, a numpy RandomState, or None"
with pytest.raises(ValueError, match=msg):
Expand Down