Skip to content

Commit 0e19474

Browse files
committed
[BUG] Add int-array-like into random state
1 parent 970499d commit 0e19474

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

pandas/_testing.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,11 @@ def _get_ilevel_values(index, level):
706706
if isinstance(left, pd.PeriodIndex) or isinstance(right, pd.PeriodIndex):
707707
assert_attr_equal("freq", left, right, obj=obj)
708708
if isinstance(left, pd.IntervalIndex) or isinstance(right, pd.IntervalIndex):
709-
assert_interval_array_equal(left._values, right._values)
709+
assert_interval_array_equal(left.values, right.values)
710710

711711
if check_categorical:
712712
if is_categorical_dtype(left) or is_categorical_dtype(right):
713-
assert_categorical_equal(left._values, right._values, obj=f"{obj} category")
713+
assert_categorical_equal(left.values, right.values, obj=f"{obj} category")
714714

715715

716716
def assert_class_equal(left, right, exact: Union[bool, str] = True, obj="Input"):
@@ -883,7 +883,7 @@ def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray")
883883
def assert_period_array_equal(left, right, obj="PeriodArray"):
884884
_check_isinstance(left, right, PeriodArray)
885885

886-
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
886+
assert_numpy_array_equal(left._data, right._data, obj=f"{obj}.values")
887887
assert_attr_equal("freq", left, right, obj=obj)
888888

889889

@@ -1170,10 +1170,10 @@ def assert_series_equal(
11701170

11711171
# datetimelike may have different objects (e.g. datetime.datetime
11721172
# vs Timestamp) but will compare equal
1173-
if not Index(left._values).equals(Index(right._values)):
1173+
if not Index(left.values).equals(Index(right.values)):
11741174
msg = (
1175-
f"[datetimelike_compat=True] {left._values} "
1176-
f"is not equal to {right._values}."
1175+
f"[datetimelike_compat=True] {left.values} "
1176+
f"is not equal to {right.values}."
11771177
)
11781178
raise AssertionError(msg)
11791179
else:
@@ -1212,8 +1212,8 @@ def assert_series_equal(
12121212
if check_categorical:
12131213
if is_categorical_dtype(left) or is_categorical_dtype(right):
12141214
assert_categorical_equal(
1215-
left._values,
1216-
right._values,
1215+
left.values,
1216+
right.values,
12171217
obj=f"{obj} category",
12181218
check_category_order=check_category_order,
12191219
)
@@ -2125,7 +2125,7 @@ def makeMissingCustomDataframe(
21252125
Density : float, optional
21262126
Float in (0, 1) that gives the percentage of non-missing numbers in
21272127
the DataFrame.
2128-
random_state : {np.random.RandomState, int}, optional
2128+
random_state : {np.random.RandomState, int, integer array like}, optional
21292129
Random number generator or random seed.
21302130
21312131
See makeCustomDataframe for descriptions of the rest of the parameters.

pandas/core/common.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -408,13 +408,16 @@ def random_state(state=None):
408408
"""
409409
if is_integer(state):
410410
return np.random.RandomState(state)
411+
elif is_array_like(state) and all(is_integer(item) for item in state):
412+
return np.random.RandomState(state)
411413
elif isinstance(state, np.random.RandomState):
412414
return state
413415
elif state is None:
414416
return np.random
415417
else:
416418
raise ValueError(
417-
"random_state must be an integer, a numpy RandomState, or None"
419+
"random_state must be an integer, integer array like,"
420+
"a numpy RandomState, or None"
418421
)
419422

420423

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4692,7 +4692,7 @@ def sample(
46924692
If weights do not sum to 1, they will be normalized to sum to 1.
46934693
Missing values in the weights column will be treated as zero.
46944694
Infinite values not allowed.
4695-
random_state : int or numpy.random.RandomState, optional
4695+
random_state : int, integer array like or numpy.random.RandomState, optional
46964696
Seed for the random number generator (if int), or numpy RandomState
46974697
object.
46984698
axis : {0 or ‘index’, 1 or ‘columns’, None}, default None

pandas/tests/test_common.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,18 @@ def test_random_state():
5656
state2 = npr.RandomState(10)
5757
assert com.random_state(state2).uniform() == npr.RandomState(10).uniform()
5858

59+
# Check with array ike
60+
state3 = com.random_state([1, 5, 10])
61+
assert com.random_state(state3).uniform() == npr.RandomState([1, 5, 10]).uniform()
62+
5963
# check with no arg random state
6064
assert com.random_state() is np.random
6165

6266
# Error for floats or strings
63-
msg = "random_state must be an integer, a numpy RandomState, or None"
67+
msg = (
68+
"random_state must be an integer, integer array like,"
69+
"a numpy RandomState, or None"
70+
)
6471
with pytest.raises(ValueError, match=msg):
6572
com.random_state("test")
6673

0 commit comments

Comments
 (0)