diff --git a/pandas/_testing.py b/pandas/_testing.py index 33ec4e4886aa6..1b6a350e814f4 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -2125,7 +2125,7 @@ def makeMissingCustomDataframe( Density : float, optional Float in (0, 1) that gives the percentage of non-missing numbers in the DataFrame. - random_state : {np.random.RandomState, int}, optional + random_state : {np.random.RandomState, int, integer array like}, optional Random number generator or random seed. See makeCustomDataframe for descriptions of the rest of the parameters. diff --git a/pandas/core/common.py b/pandas/core/common.py index 6230ee34bcd50..aa3b417bdc31a 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -22,6 +22,7 @@ is_bool_dtype, is_extension_array_dtype, is_integer, + is_list_like, ) from pandas.core.dtypes.generic import ABCIndex, ABCIndexClass, ABCSeries from pandas.core.dtypes.inference import _iterable_not_string @@ -408,13 +409,16 @@ def random_state(state=None): """ if is_integer(state): return np.random.RandomState(state) + elif is_list_like(state) and all(is_integer(item) for item in state): + return np.random.RandomState(state) elif isinstance(state, np.random.RandomState): return state elif state is None: return np.random else: raise ValueError( - "random_state must be an integer, a numpy RandomState, or None" + "random_state must be an integer, integer array like," + "a numpy RandomState, or None" ) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index f7196073162df..c121d426274b8 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4691,7 +4691,7 @@ def sample( If weights do not sum to 1, they will be normalized to sum to 1. Missing values in the weights column will be treated as zero. Infinite values not allowed. - random_state : int or numpy.random.RandomState, optional + random_state : int, integer array like or numpy.random.RandomState, optional Seed for the random number generator (if int), or numpy RandomState object. axis : {0 or ‘index’, 1 or ‘columns’, None}, default None diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 186c735a0bff9..da5d955e2c71d 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -56,11 +56,18 @@ def test_random_state(): state2 = npr.RandomState(10) assert com.random_state(state2).uniform() == npr.RandomState(10).uniform() + # Check with array ike + state3 = com.random_state([1, 5, 10]) + assert com.random_state(state3).uniform() == npr.RandomState([1, 5, 10]).uniform() + # check with no arg random state assert com.random_state() is np.random # Error for floats or strings - msg = "random_state must be an integer, a numpy RandomState, or None" + msg = ( + "random_state must be an integer, integer array like," + "a numpy RandomState, or None" + ) with pytest.raises(ValueError, match=msg): com.random_state("test")