Skip to content

Commit d0221cb

Browse files
authored
CLN: Use np.random.RandomState instead of tm.RNGContext (pandas-dev#50915)
1 parent e4df678 commit d0221cb

File tree

9 files changed

+79
-136
lines changed

9 files changed

+79
-136
lines changed

pandas/_testing/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@
100100
get_obj,
101101
)
102102
from pandas._testing.contexts import (
103-
RNGContext,
104103
decompress_file,
105104
ensure_clean,
106105
ensure_safe_environment_variables,
@@ -1135,7 +1134,6 @@ def shares_memory(left, right) -> bool:
11351134
"raise_assert_detail",
11361135
"rands",
11371136
"reset_display_options",
1138-
"RNGContext",
11391137
"raises_chained_assignment_error",
11401138
"round_trip_localpath",
11411139
"round_trip_pathlib",

pandas/_testing/contexts.py

-37
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
import os
55
from pathlib import Path
66
import tempfile
7-
from types import TracebackType
87
from typing import (
98
IO,
109
Any,
1110
Generator,
1211
)
1312
import uuid
1413

15-
import numpy as np
16-
1714
from pandas.compat import PYPY
1815
from pandas.errors import ChainedAssignmentError
1916

@@ -198,40 +195,6 @@ def use_numexpr(use, min_elements=None) -> Generator[None, None, None]:
198195
set_option("compute.use_numexpr", olduse)
199196

200197

201-
class RNGContext:
202-
"""
203-
Context manager to set the numpy random number generator seed. Returns
204-
to the original value upon exiting the context manager.
205-
206-
Parameters
207-
----------
208-
seed : int
209-
Seed for numpy.random.seed
210-
211-
Examples
212-
--------
213-
with RNGContext(42):
214-
np.random.randn()
215-
"""
216-
217-
def __init__(self, seed) -> None:
218-
self.seed = seed
219-
220-
def __enter__(self) -> None:
221-
222-
self.start_state = np.random.get_state()
223-
np.random.seed(self.seed)
224-
225-
def __exit__(
226-
self,
227-
exc_type: type[BaseException] | None,
228-
exc_value: BaseException | None,
229-
traceback: TracebackType | None,
230-
) -> None:
231-
232-
np.random.set_state(self.start_state)
233-
234-
235198
def raises_chained_assignment_error():
236199

237200
if PYPY:

pandas/tests/plotting/conftest.py

+20-21
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,30 @@
55
DataFrame,
66
to_datetime,
77
)
8-
import pandas._testing as tm
98

109

1110
@pytest.fixture
1211
def hist_df():
1312
n = 100
14-
with tm.RNGContext(42):
15-
gender = np.random.choice(["Male", "Female"], size=n)
16-
classroom = np.random.choice(["A", "B", "C"], size=n)
13+
np_random = np.random.RandomState(42)
14+
gender = np_random.choice(["Male", "Female"], size=n)
15+
classroom = np_random.choice(["A", "B", "C"], size=n)
1716

18-
hist_df = DataFrame(
19-
{
20-
"gender": gender,
21-
"classroom": classroom,
22-
"height": np.random.normal(66, 4, size=n),
23-
"weight": np.random.normal(161, 32, size=n),
24-
"category": np.random.randint(4, size=n),
25-
"datetime": to_datetime(
26-
np.random.randint(
27-
812419200000000000,
28-
819331200000000000,
29-
size=n,
30-
dtype=np.int64,
31-
)
32-
),
33-
}
34-
)
17+
hist_df = DataFrame(
18+
{
19+
"gender": gender,
20+
"classroom": classroom,
21+
"height": np.random.normal(66, 4, size=n),
22+
"weight": np.random.normal(161, 32, size=n),
23+
"category": np.random.randint(4, size=n),
24+
"datetime": to_datetime(
25+
np.random.randint(
26+
812419200000000000,
27+
819331200000000000,
28+
size=n,
29+
dtype=np.int64,
30+
)
31+
),
32+
}
33+
)
3534
return hist_df

pandas/tests/plotting/frame/test_frame.py

+52-54
Original file line numberDiff line numberDiff line change
@@ -366,51 +366,51 @@ def _compare_stacked_y_cood(self, normal_lines, stacked_lines):
366366

367367
@pytest.mark.parametrize("kind", ["line", "area"])
368368
def test_line_area_stacked(self, kind):
369-
with tm.RNGContext(42):
370-
df = DataFrame(np.random.rand(6, 4), columns=["w", "x", "y", "z"])
371-
neg_df = -df
372-
# each column has either positive or negative value
373-
sep_df = DataFrame(
374-
{
375-
"w": np.random.rand(6),
376-
"x": np.random.rand(6),
377-
"y": -np.random.rand(6),
378-
"z": -np.random.rand(6),
379-
}
380-
)
381-
# each column has positive-negative mixed value
382-
mixed_df = DataFrame(
383-
np.random.randn(6, 4),
384-
index=list(string.ascii_letters[:6]),
385-
columns=["w", "x", "y", "z"],
386-
)
369+
np_random = np.random.RandomState(42)
370+
df = DataFrame(np_random.rand(6, 4), columns=["w", "x", "y", "z"])
371+
neg_df = -df
372+
# each column has either positive or negative value
373+
sep_df = DataFrame(
374+
{
375+
"w": np_random.rand(6),
376+
"x": np_random.rand(6),
377+
"y": -np_random.rand(6),
378+
"z": -np_random.rand(6),
379+
}
380+
)
381+
# each column has positive-negative mixed value
382+
mixed_df = DataFrame(
383+
np_random.randn(6, 4),
384+
index=list(string.ascii_letters[:6]),
385+
columns=["w", "x", "y", "z"],
386+
)
387387

388-
ax1 = _check_plot_works(df.plot, kind=kind, stacked=False)
389-
ax2 = _check_plot_works(df.plot, kind=kind, stacked=True)
390-
self._compare_stacked_y_cood(ax1.lines, ax2.lines)
388+
ax1 = _check_plot_works(df.plot, kind=kind, stacked=False)
389+
ax2 = _check_plot_works(df.plot, kind=kind, stacked=True)
390+
self._compare_stacked_y_cood(ax1.lines, ax2.lines)
391391

392-
ax1 = _check_plot_works(neg_df.plot, kind=kind, stacked=False)
393-
ax2 = _check_plot_works(neg_df.plot, kind=kind, stacked=True)
394-
self._compare_stacked_y_cood(ax1.lines, ax2.lines)
392+
ax1 = _check_plot_works(neg_df.plot, kind=kind, stacked=False)
393+
ax2 = _check_plot_works(neg_df.plot, kind=kind, stacked=True)
394+
self._compare_stacked_y_cood(ax1.lines, ax2.lines)
395395

396-
ax1 = _check_plot_works(sep_df.plot, kind=kind, stacked=False)
397-
ax2 = _check_plot_works(sep_df.plot, kind=kind, stacked=True)
398-
self._compare_stacked_y_cood(ax1.lines[:2], ax2.lines[:2])
399-
self._compare_stacked_y_cood(ax1.lines[2:], ax2.lines[2:])
396+
ax1 = _check_plot_works(sep_df.plot, kind=kind, stacked=False)
397+
ax2 = _check_plot_works(sep_df.plot, kind=kind, stacked=True)
398+
self._compare_stacked_y_cood(ax1.lines[:2], ax2.lines[:2])
399+
self._compare_stacked_y_cood(ax1.lines[2:], ax2.lines[2:])
400400

401-
_check_plot_works(mixed_df.plot, stacked=False)
402-
msg = (
403-
"When stacked is True, each column must be either all positive or "
404-
"all negative. Column 'w' contains both positive and negative "
405-
"values"
406-
)
407-
with pytest.raises(ValueError, match=msg):
408-
mixed_df.plot(stacked=True)
401+
_check_plot_works(mixed_df.plot, stacked=False)
402+
msg = (
403+
"When stacked is True, each column must be either all positive or "
404+
"all negative. Column 'w' contains both positive and negative "
405+
"values"
406+
)
407+
with pytest.raises(ValueError, match=msg):
408+
mixed_df.plot(stacked=True)
409409

410-
# Use an index with strictly positive values, preventing
411-
# matplotlib from warning about ignoring xlim
412-
df2 = df.set_index(df.index + 1)
413-
_check_plot_works(df2.plot, kind=kind, logx=True, stacked=True)
410+
# Use an index with strictly positive values, preventing
411+
# matplotlib from warning about ignoring xlim
412+
df2 = df.set_index(df.index + 1)
413+
_check_plot_works(df2.plot, kind=kind, logx=True, stacked=True)
414414

415415
def test_line_area_nan_df(self):
416416
values1 = [1, 2, np.nan, 3]
@@ -1237,20 +1237,18 @@ def test_all_invalid_plot_data(self):
12371237
df.plot(kind=kind)
12381238

12391239
def test_partially_invalid_plot_data(self):
1240-
with tm.RNGContext(42):
1241-
df = DataFrame(np.random.randn(10, 2), dtype=object)
1242-
df[np.random.rand(df.shape[0]) > 0.5] = "a"
1243-
for kind in plotting.PlotAccessor._common_kinds:
1244-
msg = "no numeric data to plot"
1245-
with pytest.raises(TypeError, match=msg):
1246-
df.plot(kind=kind)
1247-
1248-
with tm.RNGContext(42):
1249-
# area plot doesn't support positive/negative mixed data
1250-
df = DataFrame(np.random.rand(10, 2), dtype=object)
1251-
df[np.random.rand(df.shape[0]) > 0.5] = "a"
1252-
with pytest.raises(TypeError, match="no numeric data to plot"):
1253-
df.plot(kind="area")
1240+
df = DataFrame(np.random.RandomState(42).randn(10, 2), dtype=object)
1241+
df[np.random.rand(df.shape[0]) > 0.5] = "a"
1242+
for kind in plotting.PlotAccessor._common_kinds:
1243+
msg = "no numeric data to plot"
1244+
with pytest.raises(TypeError, match=msg):
1245+
df.plot(kind=kind)
1246+
1247+
# area plot doesn't support positive/negative mixed data
1248+
df = DataFrame(np.random.RandomState(42).rand(10, 2), dtype=object)
1249+
df[np.random.rand(df.shape[0]) > 0.5] = "a"
1250+
with pytest.raises(TypeError, match="no numeric data to plot"):
1251+
df.plot(kind="area")
12541252

12551253
def test_invalid_kind(self):
12561254
df = DataFrame(np.random.randn(10, 2))

pandas/tests/plotting/test_boxplot_method.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,7 @@ def test_grouped_plot_fignums(self):
362362
n = 10
363363
weight = Series(np.random.normal(166, 20, size=n))
364364
height = Series(np.random.normal(60, 10, size=n))
365-
with tm.RNGContext(42):
366-
gender = np.random.choice(["male", "female"], size=n)
365+
gender = np.random.RandomState(42).choice(["male", "female"], size=n)
367366
df = DataFrame({"height": height, "weight": weight, "gender": gender})
368367
gb = df.groupby("gender")
369368

pandas/tests/plotting/test_groupby.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ def test_series_groupby_plotting_nominally_works(self):
2121
n = 10
2222
weight = Series(np.random.normal(166, 20, size=n))
2323
height = Series(np.random.normal(60, 10, size=n))
24-
with tm.RNGContext(42):
25-
gender = np.random.choice(["male", "female"], size=n)
24+
gender = np.random.RandomState(42).choice(["male", "female"], size=n)
2625

2726
weight.groupby(gender).plot()
2827
tm.close()

pandas/tests/plotting/test_hist_method.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,9 @@ def test_hist_df_kwargs(self):
510510

511511
def test_hist_df_with_nonnumerics(self):
512512
# GH 9853
513-
with tm.RNGContext(1):
514-
df = DataFrame(np.random.randn(10, 4), columns=["A", "B", "C", "D"])
513+
df = DataFrame(
514+
np.random.RandomState(42).randn(10, 4), columns=["A", "B", "C", "D"]
515+
)
515516
df["E"] = ["x", "y"] * 5
516517
_, ax = self.plt.subplots()
517518
ax = df.plot.hist(bins=5, ax=ax)
@@ -665,8 +666,7 @@ def test_grouped_hist_legacy2(self):
665666
n = 10
666667
weight = Series(np.random.normal(166, 20, size=n))
667668
height = Series(np.random.normal(60, 10, size=n))
668-
with tm.RNGContext(42):
669-
gender_int = np.random.choice([0, 1], size=n)
669+
gender_int = np.random.RandomState(42).choice([0, 1], size=n)
670670
df_int = DataFrame({"height": height, "weight": weight, "gender": gender_int})
671671
gb = df_int.groupby("gender")
672672
axes = gb.hist()

pandas/tests/plotting/test_misc.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ def test_scatter_matrix_axis(self, pass_axis):
102102
if pass_axis:
103103
_, ax = self.plt.subplots(3, 3)
104104

105-
with tm.RNGContext(42):
106-
df = DataFrame(np.random.randn(100, 3))
105+
df = DataFrame(np.random.RandomState(42).randn(100, 3))
107106

108107
# we are plotting multiples on a sub-plot
109108
with tm.assert_produces_warning(UserWarning, check_stacklevel=False):

pandas/tests/util/test_util.py

-12
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,6 @@ def test_datapath(datapath):
5858
assert result == expected
5959

6060

61-
def test_rng_context():
62-
import numpy as np
63-
64-
expected0 = 1.764052345967664
65-
expected1 = 1.6243453636632417
66-
67-
with tm.RNGContext(0):
68-
with tm.RNGContext(1):
69-
assert np.random.randn() == expected1
70-
assert np.random.randn() == expected0
71-
72-
7361
def test_external_error_raised():
7462
with tm.external_error_raised(TypeError):
7563
raise TypeError("Should not check this error message, so it will pass")

0 commit comments

Comments
 (0)