diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 21e59805fa143..6bd5723c5efcf 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -69,6 +69,7 @@ Other enhancements - `OptionError` is now exposed in `pandas.errors` (:issue:`27553`) - :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`) - Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`) +- :meth:`DataFrame.sample` will now also allow array-like and BitGenerator objects to be passed to ``random_state`` as seeds (:issue:`32503`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/common.py b/pandas/core/common.py index 6230ee34bcd50..d9ce65d44177b 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -15,6 +15,7 @@ from pandas._libs import lib, tslibs from pandas._typing import T +from pandas.compat.numpy import _np_version_under1p17 from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike from pandas.core.dtypes.common import ( @@ -395,18 +396,30 @@ def random_state(state=None): Parameters ---------- - state : int, np.random.RandomState, None. - If receives an int, passes to np.random.RandomState() as seed. + state : int, array-like, BitGenerator (NumPy>=1.17), np.random.RandomState, None. + If receives an int, array-like, or BitGenerator, passes to + np.random.RandomState() as seed. If receives an np.random.RandomState object, just returns object. If receives `None`, returns np.random. If receives anything else, raises an informative ValueError. + + ..versionchanged:: 1.1.0 + + array-like and BitGenerator (for NumPy>=1.17) object now passed to + np.random.RandomState() as seed + Default None. Returns ------- np.random.RandomState + """ - if is_integer(state): + if ( + is_integer(state) + or is_array_like(state) + or (not _np_version_under1p17 and isinstance(state, np.random.BitGenerator)) + ): return np.random.RandomState(state) elif isinstance(state, np.random.RandomState): return state @@ -414,7 +427,10 @@ def random_state(state=None): return np.random else: raise ValueError( - "random_state must be an integer, a numpy RandomState, or None" + ( + "random_state must be an integer, array-like, a BitGenerator, " + "a numpy RandomState, or None" + ) ) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 8d56311331d4d..ba4d6dee7f984 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4794,9 +4794,16 @@ def sample( If weights do not sum to 1, they will be normalized to sum to 1. Missing values in the weights column will be treated as zero. Infinite values not allowed. - random_state : int or numpy.random.RandomState, optional - Seed for the random number generator (if int), or numpy RandomState - object. + random_state : int, array-like, BitGenerator, np.random.RandomState, optional + If int, array-like, or BitGenerator (NumPy>=1.17), seed for + random number generator + If np.random.RandomState, use as numpy RandomState object. + + ..versionchanged:: 1.1.0 + + array-like and BitGenerator (for NumPy>=1.17) object now passed to + np.random.RandomState() as seed + axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis for given data type (0 for Series and DataFrames). diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 9c759ff1414fe..6999dea6adfa3 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -3,11 +3,14 @@ import numpy as np import pytest +from pandas.compat.numpy import _np_version_under1p17 + from pandas.core.dtypes.common import is_scalar import pandas as pd from pandas import DataFrame, MultiIndex, Series, date_range import pandas._testing as tm +import pandas.core.common as com # ---------------------------------------------------------------------- # Generic types test cases @@ -641,6 +644,29 @@ def test_sample(sel): with pytest.raises(ValueError): df.sample(1, weights=s4) + @pytest.mark.parametrize( + "func_str,arg", + [ + ("np.array", [2, 3, 1, 0]), + pytest.param( + "np.random.MT19937", + 3, + marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"), + ), + pytest.param( + "np.random.PCG64", + 11, + marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"), + ), + ], + ) + def test_sample_random_state(self, func_str, arg): + # GH32503 + df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)}) + result = df.sample(n=3, random_state=eval(func_str)(arg)) + expected = df.sample(n=3, random_state=com.random_state(eval(func_str)(arg))) + tm.assert_frame_equal(result, expected) + def test_squeeze(self): # noop for s in [tm.makeFloatSeries(), tm.makeStringSeries(), tm.makeObjectSeries()]: diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 186c735a0bff9..bcfed2d0d3a10 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -6,6 +6,8 @@ import numpy as np import pytest +from pandas.compat.numpy import _np_version_under1p17 + import pandas as pd from pandas import Series, Timestamp from pandas.core import ops @@ -59,8 +61,31 @@ def test_random_state(): # check with no arg random state assert com.random_state() is np.random + # check array-like + # GH32503 + state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32") + assert ( + com.random_state(state_arr_like).uniform() + == npr.RandomState(state_arr_like).uniform() + ) + + # Check BitGenerators + # GH32503 + if not _np_version_under1p17: + assert ( + com.random_state(npr.MT19937(3)).uniform() + == npr.RandomState(npr.MT19937(3)).uniform() + ) + assert ( + com.random_state(npr.PCG64(11)).uniform() + == npr.RandomState(npr.PCG64(11)).uniform() + ) + # Error for floats or strings - msg = "random_state must be an integer, a numpy RandomState, or None" + msg = ( + "random_state must be an integer, array-like, a BitGenerator, " + "a numpy RandomState, or None" + ) with pytest.raises(ValueError, match=msg): com.random_state("test")