From 876ee77b04bf51d0e15402a8a086d0fcad39cf22 Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Fri, 6 Mar 2020 21:49:38 -0500 Subject: [PATCH 1/7] [BUG] Loosen random_state input restriction Allow for array-like as well as BitGenerator inputs Addresses: GH32503 --- pandas/core/common.py | 6 +++++- pandas/tests/test_common.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/pandas/core/common.py b/pandas/core/common.py index 6230ee34bcd50..ec13947fda24f 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -406,7 +406,11 @@ def random_state(state=None): ------- np.random.RandomState """ - if is_integer(state): + if ( + is_integer(state) + or is_array_like(state) + or isinstance(state, np.random.BitGenerator) + ): return np.random.RandomState(state) elif isinstance(state, np.random.RandomState): return state diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 186c735a0bff9..9008729554cf6 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -59,6 +59,23 @@ def test_random_state(): # check with no arg random state assert com.random_state() is np.random + # check array-like + state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32") + assert ( + com.random_state(state_arr_like).uniform() + == npr.RandomState(state_arr_like).uniform() + ) + + # Check BitGenerators + assert ( + com.random_state(npr.MT19937(3)).uniform() + == npr.RandomState(npr.MT19937(3)).uniform() + ) + assert ( + com.random_state(npr.PCG64(11)).uniform() + == npr.RandomState(npr.PCG64(11)).uniform() + ) + # Error for floats or strings msg = "random_state must be an integer, a numpy RandomState, or None" with pytest.raises(ValueError, match=msg): From fb98ba303f6e00296b52f343517221e2b647f4ba Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Mon, 9 Mar 2020 20:32:30 -0400 Subject: [PATCH 2/7] Update whatsnew and docstrings; Add compatibility check --- doc/source/whatsnew/v1.1.0.rst | 1 + 
pandas/core/common.py | 12 +++++++++--- pandas/tests/test_common.py | 19 +++++++++++-------- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 44deab25db695..83a6e391b855d 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -67,6 +67,7 @@ Other enhancements - When writing directly to a sqlite connection :func:`to_sql` now supports the ``multi`` method (:issue:`29921`) - `OptionError` is now exposed in `pandas.errors` (:issue:`27553`) - :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`) +- :func:`core.common.random_state` will now pass array-like and BitGenerator objects through to `np.random.RandomState` as seeds (:issue:`32503`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/common.py b/pandas/core/common.py index ec13947fda24f..1493635f55064 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -15,6 +15,7 @@ from pandas._libs import lib, tslibs from pandas._typing import T +from pandas.compat.numpy import _np_version_under1p17 from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike from pandas.core.dtypes.common import ( @@ -395,8 +396,9 @@ def random_state(state=None): Parameters ---------- - state : int, np.random.RandomState, None. - If receives an int, passes to np.random.RandomState() as seed. + state : int, array-like, BitGenerator (NumPy>=1.17), np.random.RandomState, None. + If receives an int, array-like, or BitGenerator, passes to + np.random.RandomState() as seed. If receives an np.random.RandomState object, just returns object. If receives `None`, returns np.random. If receives anything else, raises an informative ValueError. 
@@ -405,11 +407,15 @@ def random_state(state=None): Returns ------- np.random.RandomState + + ..versionchanged:: 1.1.0 + array-like and BitGenerator (for NumPy>=1.17) object now passed to + np.random.RandomState() as seed """ if ( is_integer(state) or is_array_like(state) - or isinstance(state, np.random.BitGenerator) + or (not _np_version_under1p17 and isinstance(state, np.random.BitGenerator)) ): return np.random.RandomState(state) elif isinstance(state, np.random.RandomState): diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 9008729554cf6..59aa000d30c7e 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -6,6 +6,8 @@ import numpy as np import pytest +from pandas.compat.numpy import _np_version_under1p17 + import pandas as pd from pandas import Series, Timestamp from pandas.core import ops @@ -67,14 +69,15 @@ def test_random_state(): ) # Check BitGenerators - assert ( - com.random_state(npr.MT19937(3)).uniform() - == npr.RandomState(npr.MT19937(3)).uniform() - ) - assert ( - com.random_state(npr.PCG64(11)).uniform() - == npr.RandomState(npr.PCG64(11)).uniform() - ) + if not _np_version_under1p17: + assert ( + com.random_state(npr.MT19937(3)).uniform() + == npr.RandomState(npr.MT19937(3)).uniform() + ) + assert ( + com.random_state(npr.PCG64(11)).uniform() + == npr.RandomState(npr.PCG64(11)).uniform() + ) # Error for floats or strings msg = "random_state must be an integer, a numpy RandomState, or None" From 80f238f1fc345e0385f32b1010275abea55859ba Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Wed, 11 Mar 2020 20:32:41 -0400 Subject: [PATCH 3/7] Add tests to pandas.core.generic for .sample; Update whatsnew and docs --- doc/source/whatsnew/v1.1.0.rst | 2 +- pandas/core/common.py | 5 ++++- pandas/core/generic.py | 7 ++++--- pandas/tests/generic/test_generic.py | 21 +++++++++++++++++++++ 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/doc/source/whatsnew/v1.1.0.rst 
b/doc/source/whatsnew/v1.1.0.rst index 6ebcbb275b1cb..b0c316d22e6ed 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -68,7 +68,7 @@ Other enhancements - `OptionError` is now exposed in `pandas.errors` (:issue:`27553`) - :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`) - Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`) -- :func:`core.common.random_state` will now pass array-like and BitGenerator objects through to `np.random.RandomState` as seeds (:issue:`32503`) +- :meth:`DataFrame.sample` will now also allow array-like and BitGenerator objects to be passed to ``random_state`` as seeds (:issue:`32503`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/common.py b/pandas/core/common.py index 1493635f55064..6716585298752 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -424,7 +424,10 @@ def random_state(state=None): return np.random else: raise ValueError( - "random_state must be an integer, a numpy RandomState, or None" + ( + "random_state must be an integer, array-like, a BitGenerator, " + "a numpy RandomState, or None" + ) ) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index e6c5ac9dbf733..07852b833284e 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4788,9 +4788,10 @@ def sample( If weights do not sum to 1, they will be normalized to sum to 1. Missing values in the weights column will be treated as zero. Infinite values not allowed. - random_state : int or numpy.random.RandomState, optional - Seed for the random number generator (if int), or numpy RandomState - object. + random_state : int, array-like, BitGenerator, np.random.RandomState, optional + If int, array-like, or BitGenerator (NumPy>=1.17), seed for + random number generator + If np.random.RandomState, use as numpy RandomState object. 
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis for given data type (0 for Series and DataFrames). diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 1b6cb8447c76d..4c58f1fdcfea8 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -3,11 +3,14 @@ import numpy as np import pytest +from pandas.compat.numpy import _np_version_under1p17 + from pandas.core.dtypes.common import is_scalar import pandas as pd from pandas import DataFrame, MultiIndex, Series, date_range import pandas._testing as tm +import pandas.core.common as com # ---------------------------------------------------------------------- # Generic types test cases @@ -654,6 +657,24 @@ def test_sample(sel): with pytest.raises(ValueError): df.sample(1, weights=s4) + # Check that random_state arguments are processed correctly + df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)}) + seed_arr = np.random.randint(4, size=10) + tm.assert_frame_equal( + df.sample(n=3, random_state=com.random_state(seed_arr)), + df.sample(n=3, random_state=seed_arr), + ) + + if not _np_version_under1p17: + tm.assert_frame_equal( + df.sample(n=3, random_state=com.random_state(np.random.MT19937(3))), + df.sample(n=3, random_state=np.random.MT19937(3)), + ) + tm.assert_frame_equal( + df.sample(n=3, random_state=com.random_state(np.random.PCG64(11))), + df.sample(n=3, random_state=np.random.PCG64(11)), + ) + def test_squeeze(self): # noop for s in [tm.makeFloatSeries(), tm.makeStringSeries(), tm.makeObjectSeries()]: From 33301c7219c583d5978bb712cfced8ce21967be1 Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Wed, 11 Mar 2020 22:24:27 -0400 Subject: [PATCH 4/7] Fix test_common to reflect new ValueError msg --- pandas/tests/test_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py 
index 59aa000d30c7e..9a0d03c61716d 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -80,7 +80,10 @@ def test_random_state(): ) # Error for floats or strings - msg = "random_state must be an integer, a numpy RandomState, or None" + msg = ( + "random_state must be an integer, array-like, a BitGenerator, " + "a numpy RandomState, or None" + ) with pytest.raises(ValueError, match=msg): com.random_state("test") From c97508d92be19623529596d3c67aac41840a8a8e Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Sat, 14 Mar 2020 21:25:14 -0400 Subject: [PATCH 5/7] Move sample-random_state tests into separate test --- pandas/tests/generic/test_generic.py | 34 ++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index e2f397e2a7aaf..2bc3509209ab7 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -644,24 +644,30 @@ def test_sample(sel): with pytest.raises(ValueError): df.sample(1, weights=s4) - # Check that random_state arguments are processed correctly + @pytest.mark.parametrize( + "func,arg", + [ + (np.array, [2, 3, 1, 0]), + pytest.param( + np.random.MT19937, + 3, + marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"), + ), + pytest.param( + np.random.PCG64, + 11, + marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"), + ), + ], + ) + def test_sample_random_state(self, func, arg): + # GH32503 df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)}) - seed_arr = np.random.randint(4, size=10) tm.assert_frame_equal( - df.sample(n=3, random_state=com.random_state(seed_arr)), - df.sample(n=3, random_state=seed_arr), + df.sample(n=3, random_state=com.random_state(func(arg))), + df.sample(n=3, random_state=func(arg)), ) - if not _np_version_under1p17: - tm.assert_frame_equal( - df.sample(n=3, random_state=com.random_state(np.random.MT19937(3))), - 
df.sample(n=3, random_state=np.random.MT19937(3)), - ) - tm.assert_frame_equal( - df.sample(n=3, random_state=com.random_state(np.random.PCG64(11))), - df.sample(n=3, random_state=np.random.PCG64(11)), - ) - def test_squeeze(self): # noop for s in [tm.makeFloatSeries(), tm.makeStringSeries(), tm.makeObjectSeries()]: From ba736ce22264c2c0f5844fc054e0bfdd25ef8330 Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Sat, 14 Mar 2020 21:47:12 -0400 Subject: [PATCH 6/7] Change pytest params to delay evaluation for skipped tests --- pandas/tests/generic/test_generic.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 2bc3509209ab7..9c2261d2c9140 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -645,27 +645,27 @@ def test_sample(sel): df.sample(1, weights=s4) @pytest.mark.parametrize( - "func,arg", + "func_str,arg", [ - (np.array, [2, 3, 1, 0]), + ("np.array", [2, 3, 1, 0]), pytest.param( - np.random.MT19937, + "np.random.MT19937", 3, marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"), ), pytest.param( - np.random.PCG64, + "np.random.PCG64", 11, marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"), ), ], ) - def test_sample_random_state(self, func, arg): + def test_sample_random_state(self, func_str, arg): # GH32503 df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)}) tm.assert_frame_equal( - df.sample(n=3, random_state=com.random_state(func(arg))), - df.sample(n=3, random_state=func(arg)), + df.sample(n=3, random_state=com.random_state(eval(func_str)(arg))), + df.sample(n=3, random_state=eval(func_str)(arg)), ) def test_squeeze(self): From 191702685edbc46e6cdb1f20a77f4ef6204af0e2 Mon Sep 17 00:00:00 2001 From: Mike Kutzma Date: Mon, 16 Mar 2020 21:59:32 -0400 Subject: [PATCH 7/7] Fix docstrings and test format --- pandas/core/common.py | 9 ++++++--- 
pandas/core/generic.py | 6 ++++++ pandas/tests/generic/test_generic.py | 7 +++---- pandas/tests/test_common.py | 2 ++ 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pandas/core/common.py b/pandas/core/common.py index 6716585298752..d9ce65d44177b 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -402,15 +402,18 @@ def random_state(state=None): If receives an np.random.RandomState object, just returns object. If receives `None`, returns np.random. If receives anything else, raises an informative ValueError. + + .. versionchanged:: 1.1.0 + + array-like and BitGenerator (for NumPy>=1.17) object now passed to + np.random.RandomState() as seed + Default None. Returns ------- np.random.RandomState - ..versionchanged:: 1.1.0 - array-like and BitGenerator (for NumPy>=1.17) object now passed to - np.random.RandomState() as seed """ if ( is_integer(state) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 4e67297cf186b..ba4d6dee7f984 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -4798,6 +4798,12 @@ def sample( If int, array-like, or BitGenerator (NumPy>=1.17), seed for random number generator If np.random.RandomState, use as numpy RandomState object. + + .. versionchanged:: 1.1.0 + + array-like and BitGenerator (for NumPy>=1.17) object now passed to + np.random.RandomState() as seed + axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis for given data type (0 for Series and DataFrames). 
diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py index 9c2261d2c9140..6999dea6adfa3 100644 --- a/pandas/tests/generic/test_generic.py +++ b/pandas/tests/generic/test_generic.py @@ -663,10 +663,9 @@ def test_sample(sel): def test_sample_random_state(self, func_str, arg): # GH32503 df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)}) - tm.assert_frame_equal( - df.sample(n=3, random_state=com.random_state(eval(func_str)(arg))), - df.sample(n=3, random_state=eval(func_str)(arg)), - ) + result = df.sample(n=3, random_state=eval(func_str)(arg)) + expected = df.sample(n=3, random_state=com.random_state(eval(func_str)(arg))) + tm.assert_frame_equal(result, expected) def test_squeeze(self): # noop diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py index 9a0d03c61716d..bcfed2d0d3a10 100644 --- a/pandas/tests/test_common.py +++ b/pandas/tests/test_common.py @@ -62,6 +62,7 @@ def test_random_state(): assert com.random_state() is np.random # check array-like + # GH32503 state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32") assert ( com.random_state(state_arr_like).uniform() @@ -69,6 +70,7 @@ def test_random_state(): ) # Check BitGenerators + # GH32503 if not _np_version_under1p17: assert ( com.random_state(npr.MT19937(3)).uniform()