Skip to content

Commit eb40124

Browse files
mikekutzma authored and SeeminSyed committed
[BUG] Loosen random_state input restriction (pandas-dev#32510)
1 parent 4a8f36d commit eb40124

File tree

5 files changed

+83
-8
lines changed

5 files changed

+83
-8
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Other enhancements
6969
- `OptionError` is now exposed in `pandas.errors` (:issue:`27553`)
7070
- :func:`timedelta_range` will now infer a frequency when passed ``start``, ``stop``, and ``periods`` (:issue:`32377`)
7171
- Positional slicing on a :class:`IntervalIndex` now supports slices with ``step > 1`` (:issue:`31658`)
72+
- :meth:`DataFrame.sample` will now also allow array-like and BitGenerator objects to be passed to ``random_state`` as seeds (:issue:`32503`)
7273
-
7374

7475
.. ---------------------------------------------------------------------------

pandas/core/common.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pandas._libs import lib, tslibs
1717
from pandas._typing import T
18+
from pandas.compat.numpy import _np_version_under1p17
1819

1920
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
2021
from pandas.core.dtypes.common import (
@@ -392,26 +393,41 @@ def random_state(state=None):
392393
393394
Parameters
394395
----------
395-
state : int, np.random.RandomState, None.
396-
If receives an int, passes to np.random.RandomState() as seed.
396+
state : int, array-like, BitGenerator (NumPy>=1.17), np.random.RandomState, None.
397+
If receives an int, array-like, or BitGenerator, passes to
398+
np.random.RandomState() as seed.
397399
If receives an np.random.RandomState object, just returns object.
398400
If receives `None`, returns np.random.
399401
If receives anything else, raises an informative ValueError.
402+
403+
.. versionchanged:: 1.1.0
404+
405+
array-like and BitGenerator (for NumPy>=1.17) objects are now passed to
406+
np.random.RandomState() as seed
407+
400408
Default None.
401409
402410
Returns
403411
-------
404412
np.random.RandomState
413+
405414
"""
406-
if is_integer(state):
415+
if (
416+
is_integer(state)
417+
or is_array_like(state)
418+
or (not _np_version_under1p17 and isinstance(state, np.random.BitGenerator))
419+
):
407420
return np.random.RandomState(state)
408421
elif isinstance(state, np.random.RandomState):
409422
return state
410423
elif state is None:
411424
return np.random
412425
else:
413426
raise ValueError(
414-
"random_state must be an integer, a numpy RandomState, or None"
427+
(
428+
"random_state must be an integer, array-like, a BitGenerator, "
429+
"a numpy RandomState, or None"
430+
)
415431
)
416432

417433

pandas/core/generic.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -4794,9 +4794,16 @@ def sample(
47944794
If weights do not sum to 1, they will be normalized to sum to 1.
47954795
Missing values in the weights column will be treated as zero.
47964796
Infinite values not allowed.
4797-
random_state : int or numpy.random.RandomState, optional
4798-
Seed for the random number generator (if int), or numpy RandomState
4799-
object.
4797+
random_state : int, array-like, BitGenerator, np.random.RandomState, optional
4798+
If int, array-like, or BitGenerator (NumPy>=1.17), seed for
4799+
random number generator
4800+
If np.random.RandomState, use as numpy RandomState object.
4801+
4802+
.. versionchanged:: 1.1.0
4803+
4804+
array-like and BitGenerator (for NumPy>=1.17) objects are now passed to
4805+
np.random.RandomState() as seed
4806+
48004807
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None
48014808
Axis to sample. Accepts axis number or name. Default is stat axis
48024809
for given data type (0 for Series and DataFrames).

pandas/tests/generic/test_generic.py

+26
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import numpy as np
44
import pytest
55

6+
from pandas.compat.numpy import _np_version_under1p17
7+
68
from pandas.core.dtypes.common import is_scalar
79

810
import pandas as pd
911
from pandas import DataFrame, MultiIndex, Series, date_range
1012
import pandas._testing as tm
13+
import pandas.core.common as com
1114

1215
# ----------------------------------------------------------------------
1316
# Generic types test cases
@@ -641,6 +644,29 @@ def test_sample(sel):
641644
with pytest.raises(ValueError):
642645
df.sample(1, weights=s4)
643646

647+
@pytest.mark.parametrize(
648+
"func_str,arg",
649+
[
650+
("np.array", [2, 3, 1, 0]),
651+
pytest.param(
652+
"np.random.MT19937",
653+
3,
654+
marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"),
655+
),
656+
pytest.param(
657+
"np.random.PCG64",
658+
11,
659+
marks=pytest.mark.skipif(_np_version_under1p17, reason="NumPy<1.17"),
660+
),
661+
],
662+
)
663+
def test_sample_random_state(self, func_str, arg):
664+
# GH32503
665+
df = pd.DataFrame({"col1": range(10, 20), "col2": range(20, 30)})
666+
result = df.sample(n=3, random_state=eval(func_str)(arg))
667+
expected = df.sample(n=3, random_state=com.random_state(eval(func_str)(arg)))
668+
tm.assert_frame_equal(result, expected)
669+
644670
def test_squeeze(self):
645671
# noop
646672
for s in [tm.makeFloatSeries(), tm.makeStringSeries(), tm.makeObjectSeries()]:

pandas/tests/test_common.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import pytest
88

9+
from pandas.compat.numpy import _np_version_under1p17
10+
911
import pandas as pd
1012
from pandas import Series, Timestamp
1113
from pandas.core import ops
@@ -59,8 +61,31 @@ def test_random_state():
5961
# check with no arg random state
6062
assert com.random_state() is np.random
6163

64+
# check array-like
65+
# GH32503
66+
state_arr_like = npr.randint(0, 2 ** 31, size=624, dtype="uint32")
67+
assert (
68+
com.random_state(state_arr_like).uniform()
69+
== npr.RandomState(state_arr_like).uniform()
70+
)
71+
72+
# Check BitGenerators
73+
# GH32503
74+
if not _np_version_under1p17:
75+
assert (
76+
com.random_state(npr.MT19937(3)).uniform()
77+
== npr.RandomState(npr.MT19937(3)).uniform()
78+
)
79+
assert (
80+
com.random_state(npr.PCG64(11)).uniform()
81+
== npr.RandomState(npr.PCG64(11)).uniform()
82+
)
83+
6284
# Error for floats or strings
63-
msg = "random_state must be an integer, a numpy RandomState, or None"
85+
msg = (
86+
"random_state must be an integer, array-like, a BitGenerator, "
87+
"a numpy RandomState, or None"
88+
)
6489
with pytest.raises(ValueError, match=msg):
6590
com.random_state("test")
6691

0 commit comments

Comments
 (0)