Skip to content

[BUG] Loosen random_state input restriction #32510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 17, 2020
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Other enhancements
- `OptionError` is now exposed in `pandas.errors` (:issue:`27553`)
- :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`)
- Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`)
- :meth:`DataFrame.sample` will now also allow array-like and BitGenerator objects to be passed to ``random_state`` as seeds (:issue:`32503`)
-

.. ---------------------------------------------------------------------------
Expand Down
21 changes: 17 additions & 4 deletions pandas/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from pandas._libs import lib, tslibs
from pandas._typing import T
from pandas.compat.numpy import _np_version_under1p17

from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -395,8 +396,9 @@ def random_state(state=None):

Parameters
----------
state : int, np.random.RandomState, None.
If receives an int, passes to np.random.RandomState() as seed.
state : int, array-like, BitGenerator (NumPy>=1.17), np.random.RandomState, None.
If receives an int, array-like, or BitGenerator, passes to
np.random.RandomState() as seed.
If receives an np.random.RandomState object, just returns object.
If receives `None`, returns np.random.
If receives anything else, raises an informative ValueError.
Expand All @@ -405,16 +407,27 @@ def random_state(state=None):
Returns
-------
np.random.RandomState

..versionchanged:: 1.1.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be in the Paramters section, just below the 'If receives anything'

array-like and BitGenerator (for NumPy>=1.17) object now passed to
np.random.RandomState() as seed
"""
if is_integer(state):
if (
is_integer(state)
or is_array_like(state)
or (not _np_version_under1p17 and isinstance(state, np.random.BitGenerator))
):
return np.random.RandomState(state)
elif isinstance(state, np.random.RandomState):
return state
elif state is None:
return np.random
else:
raise ValueError(
"random_state must be an integer, a numpy RandomState, or None"
(
"random_state must be an integer, array-like, a BitGenerator, "
"a numpy RandomState, or None"
)
)


Expand Down
7 changes: 4 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4794,9 +4794,10 @@ def sample(
If weights do not sum to 1, they will be normalized to sum to 1.
Missing values in the weights column will be treated as zero.
Infinite values not allowed.
random_state : int or numpy.random.RandomState, optional
Seed for the random number generator (if int), or numpy RandomState
object.
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
If int, array-like, or BitGenerator (NumPy>=1.17), seed for
random number generator
If np.random.RandomState, use as numpy RandomState object.
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
Axis to sample. Accepts axis number or name. Default is stat axis
for given data type (0 for Series and DataFrames).
Expand Down
27 changes: 27 additions & 0 deletions pandas/tests/generic/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import numpy as np
import pytest

from pandas.compat.numpy import _np_version_under1p17

from pandas.core.dtypes.common import is_scalar

import pandas as pd
from pandas import DataFrame, MultiIndex, Series, date_range
import pandas._testing as tm
import pandas.core.common as com

# ----------------------------------------------------------------------
# Generic types test cases
Expand Down Expand Up @@ -641,6 +644,30 @@ def test_sample(sel):
with pytest.raises(ValueError):
df.sample(1, weights=s4)

@pytest.mark.parametrize(
"func_str,arg",
[
("np.array", [2, 3, 1, 0]),
pytest.param(
"np.random.MT19937",
3,
marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"),
),
pytest.param(
"np.random.PCG64",
11,
marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"),
),
],
)
def test_sample_random_state(self, func_str, arg):
# GH32503
df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)})
tm.assert_frame_equal(
df.sample(n=3, random_state=com.random_state(eval(func_str)(arg))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you do

result =
expected =
tm.assert_frame_equal(result, expected)

df.sample(n=3, random_state=eval(func_str)(arg)),
)

def test_squeeze(self):
# noop
for s in [tm.makeFloatSeries(), tm.makeStringSeries(), tm.makeObjectSeries()]:
Expand Down
25 changes: 24 additions & 1 deletion pandas/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pytest

from pandas.compat.numpy import _np_version_under1p17

import pandas as pd
from pandas import Series, Timestamp
from pandas.core import ops
Expand Down Expand Up @@ -59,8 +61,29 @@ def test_random_state():
# check with no arg random state
assert com.random_state() is np.random

# check array-like
state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32")
assert (
com.random_state(state_arr_like).uniform()
== npr.RandomState(state_arr_like).uniform()
)

# Check BitGenerators
if not _np_version_under1p17:
assert (
com.random_state(npr.MT19937(3)).uniform()
== npr.RandomState(npr.MT19937(3)).uniform()
)
assert (
com.random_state(npr.PCG64(11)).uniform()
== npr.RandomState(npr.PCG64(11)).uniform()
)

# Error for floats or strings
msg = "random_state must be an integer, a numpy RandomState, or None"
msg = (
"random_state must be an integer, array-like, a BitGenerator, "
"a numpy RandomState, or None"
)
with pytest.raises(ValueError, match=msg):
com.random_state("test")

Expand Down