Skip to content

[BUG] Loosen random_state input restriction #32510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 17, 2020
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Other enhancements
- `OptionError` is now exposed in `pandas.errors` (:issue:`27553`)
- :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`)
- Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`)
- :func:`core.common.random_state` will now pass array-like and BitGenerator objects through to `np.random.RandomState` as seeds (:issue:`32503`)
-

.. ---------------------------------------------------------------------------
Expand Down
16 changes: 13 additions & 3 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from pandas._libs import lib, tslibs
from pandas._typing import T
from pandas.compat.numpy import _np_version_under1p17

from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -395,8 +396,9 @@ def random_state(state=None):

Parameters
----------
state : int, np.random.RandomState, None.
If receives an int, passes to np.random.RandomState() as seed.
state : int, array-like, BitGenerator (NumPy>=1.17), np.random.RandomState, None.
If receives an int, array-like, or BitGenerator, passes to
np.random.RandomState() as seed.
If receives an np.random.RandomState object, just returns object.
If receives `None`, returns np.random.
If receives anything else, raises an informative ValueError.
Expand All @@ -405,8 +407,16 @@ def random_state(state=None):
Returns
-------
np.random.RandomState

..versionchanged:: 1.1.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be in the Paramters section, just below the 'If receives anything'

array-like and BitGenerator (for NumPy>=1.17) object now passed to
np.random.RandomState() as seed
"""
if is_integer(state):
if (
is_integer(state)
or is_array_like(state)
or (not _np_version_under1p17 and isinstance(state, np.random.BitGenerator))
):
return np.random.RandomState(state)
elif isinstance(state, np.random.RandomState):
return state
Expand Down
20 changes: 20 additions & 0 deletions pandas/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pytest

from pandas.compat.numpy import _np_version_under1p17

import pandas as pd
from pandas import Series, Timestamp
from pandas.core import ops
Expand Down Expand Up @@ -59,6 +61,24 @@ def test_random_state():
# check with no arg random state
assert com.random_state() is np.random

# check array-like
state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32")
assert (
com.random_state(state_arr_like).uniform()
== npr.RandomState(state_arr_like).uniform()
)

# Check BitGenerators
if not _np_version_under1p17:
assert (
com.random_state(npr.MT19937(3)).uniform()
== npr.RandomState(npr.MT19937(3)).uniform()
)
assert (
com.random_state(npr.PCG64(11)).uniform()
== npr.RandomState(npr.PCG64(11)).uniform()
)

# Error for floats or strings
msg = "random_state must be an integer, a numpy RandomState, or None"
with pytest.raises(ValueError, match=msg):
Expand Down