Skip to content

Commit 39c1147

Browse files
kc611 and brandonwillard
authored and committed
Refactor remaining tests to use NumPy Generator
1 parent 1ff4b9d commit 39c1147

File tree

19 files changed

+334
-276
lines changed

19 files changed

+334
-276
lines changed

tests/compile/function/test_pfunc.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ def test_doc(self):
6464
def test_shared(self):
6565

6666
# CHECK: two functions (f1 and f2) can share w
67-
w = shared(np.random.rand(2, 2), "w")
67+
w = shared(np.random.random((2, 2)), "w")
6868
wval = w.get_value(borrow=False)
6969

7070
x = dmatrix()
7171
out1 = w + x
7272
out2 = w * x
7373
f1 = pfunc([x], [out1])
7474
f2 = pfunc([x], [out2])
75-
xval = np.random.rand(2, 2)
75+
xval = np.random.random((2, 2))
7676
assert np.all(f1(xval) == xval + wval)
7777
assert np.all(f2(xval) == xval * wval)
7878

@@ -89,7 +89,7 @@ def test_shared(self):
8989

9090
def test_no_shared_as_input(self):
9191
# Test that shared variables cannot be used as function inputs.
92-
w_init = np.random.rand(2, 2)
92+
w_init = np.random.random((2, 2))
9393
w = shared(w_init.copy(), "w")
9494
with pytest.raises(
9595
TypeError, match=r"^Cannot use a shared variable \(w\) as explicit input"
@@ -100,8 +100,8 @@ def test_default_container(self):
100100
# Ensure it is possible to (implicitly) use a shared variable in a
101101
# function, as a 'state' that can be updated at will.
102102

103-
rng = np.random.RandomState(1827)
104-
w_init = rng.rand(5)
103+
rng = np.random.default_rng(1827)
104+
w_init = rng.random((5))
105105
w = shared(w_init.copy(), "w")
106106
reg = aet_sum(w * w)
107107
f = pfunc([], reg)
@@ -127,8 +127,8 @@ def test_param_strict(self):
127127
out = a + b
128128

129129
f = pfunc([In(a, strict=False)], [out])
130-
# works, rand generates float64 by default
131-
f(np.random.rand(8))
130+
# works, random generates float64 by default
131+
f(np.random.random((8)))
132132
# works, casting is allowed
133133
f(np.array([1, 2, 3, 4], dtype="int32"))
134134

@@ -145,14 +145,14 @@ def test_param_mutable(self):
145145

146146
# using mutable=True will let fip change the value in aval
147147
fip = pfunc([In(a, mutable=True)], [a_out], mode="FAST_RUN")
148-
aval = np.random.rand(10)
148+
aval = np.random.random((10))
149149
aval2 = aval.copy()
150150
assert np.all(fip(aval) == (aval2 * 2))
151151
assert not np.all(aval == aval2)
152152

153153
# using mutable=False should leave the input untouched
154154
f = pfunc([In(a, mutable=False)], [a_out], mode="FAST_RUN")
155-
aval = np.random.rand(10)
155+
aval = np.random.random((10))
156156
aval2 = aval.copy()
157157
assert np.all(f(aval) == (aval2 * 2))
158158
assert np.all(aval == aval2)
@@ -375,7 +375,7 @@ def test_update(self):
375375

376376
def test_update_err_broadcast(self):
377377
# Test that broadcastable dimensions raise error
378-
data = np.random.rand(10, 10).astype("float32")
378+
data = np.random.random((10, 10)).astype("float32")
379379
output_var = shared(name="output", value=data)
380380

381381
# the update_var has type matrix, and the update expression

tests/compile/test_debugmode.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def perform(self, node, inp, out):
736736

737737
class TestPreallocatedOutput:
738738
def setup_method(self):
739-
self.rng = np.random.RandomState(seed=utt.fetch_seed())
739+
self.rng = np.random.default_rng(seed=utt.fetch_seed())
740740

741741
def test_f_contiguous(self):
742742
a = fmatrix("a")
@@ -745,8 +745,8 @@ def test_f_contiguous(self):
745745
# In this test, we do not want z to be an output of the graph.
746746
out = dot(z, np.eye(7))
747747

748-
a_val = self.rng.randn(7, 7).astype("float32")
749-
b_val = self.rng.randn(7, 7).astype("float32")
748+
a_val = self.rng.standard_normal((7, 7)).astype("float32")
749+
b_val = self.rng.standard_normal((7, 7)).astype("float32")
750750

751751
# Should work
752752
mode = DebugMode(check_preallocated_output=["c_contiguous"])
@@ -776,8 +776,8 @@ def test_f_contiguous_out(self):
776776
b = fmatrix("b")
777777
out = BrokenCImplementationAdd()(a, b)
778778

779-
a_val = self.rng.randn(7, 7).astype("float32")
780-
b_val = self.rng.randn(7, 7).astype("float32")
779+
a_val = self.rng.standard_normal((7, 7)).astype("float32")
780+
b_val = self.rng.standard_normal((7, 7)).astype("float32")
781781

782782
# Should work
783783
mode = DebugMode(check_preallocated_output=["c_contiguous"])
@@ -805,5 +805,5 @@ def test_output_broadcast_tensor(self):
805805
c, r = VecAsRowAndCol()(v)
806806
f = function([v], [c, r])
807807

808-
v_val = self.rng.randn(5).astype("float32")
808+
v_val = self.rng.standard_normal((5)).astype("float32")
809809
f(v_val)

tests/compile/test_misc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __init__(
5454

5555

5656
def test_nnet():
57-
rng = np.random.RandomState(1827)
58-
data = rng.rand(10, 4)
57+
rng = np.random.default_rng(279)
58+
data = rng.random((10, 4))
5959
nnet = NNet(n_input=3, n_hidden=10)
6060
for epoch in range(3):
6161
mean_cost = 0
@@ -66,7 +66,8 @@ def test_nnet():
6666
mean_cost += cost
6767
mean_cost /= float(len(data))
6868
# print 'Mean cost at epoch %s: %s' % (epoch, mean_cost)
69-
assert abs(mean_cost - 0.20588975452) < 1e-6
69+
# Seed based test
70+
assert abs(mean_cost - 0.2301901) < 1e-6
7071
# Just call functions to make sure they do not crash.
7172
nnet.compute_output(input)
7273
nnet.output_from_hidden(np.ones(10))

tests/d3viz/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, nfeatures=100, noutputs=10, nhiddens=50, rng=None):
1111
if rng is None:
1212
rng = 0
1313
if isinstance(rng, int):
14-
rng = np.random.RandomState(rng)
14+
rng = np.random.default_rng(rng)
1515
self.rng = rng
1616
self.nfeatures = nfeatures
1717
self.noutputs = noutputs

tests/d3viz/test_d3viz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
class TestD3Viz:
2121
def setup_method(self):
22-
self.rng = np.random.RandomState(0)
22+
self.rng = np.random.default_rng(0)
2323
self.data_dir = pt.join("data", "test_d3viz")
2424

2525
def check(self, f, reference=None, verbose=False):

tests/d3viz/test_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
class TestPyDotFormatter:
1515
def setup_method(self):
16-
self.rng = np.random.RandomState(0)
16+
self.rng = np.random.default_rng(0)
1717

1818
def node_counts(self, graph):
1919
node_types = [node.get_attributes()["node_type"] for node in graph.get_nodes()]

tests/link/test_jax.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ def test_jax_compile_ops():
218218

219219

220220
def test_jax_basic():
221+
rng = np.random.default_rng(28494)
222+
221223
x = matrix("x")
222224
y = matrix("y")
223225
b = vector("b")
@@ -259,15 +261,23 @@ def test_jax_basic():
259261
out_fg = FunctionGraph([x], [out])
260262
compare_jax_and_py(
261263
out_fg,
262-
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(config.floatX)],
264+
[
265+
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
266+
config.floatX
267+
)
268+
],
263269
)
264270

265271
# not sure why this isn't working yet with lower=False
266272
out = aet_slinalg.Cholesky(lower=False)(x)
267273
out_fg = FunctionGraph([x], [out])
268274
compare_jax_and_py(
269275
out_fg,
270-
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(config.floatX)],
276+
[
277+
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
278+
config.floatX
279+
)
280+
],
271281
)
272282

273283
out = aet_slinalg.solve(x, b)
@@ -294,7 +304,11 @@ def test_jax_basic():
294304
out_fg = FunctionGraph([x], [out])
295305
compare_jax_and_py(
296306
out_fg,
297-
[(np.eye(10) + np.random.randn(10, 10) * 0.01).astype(config.floatX)],
307+
[
308+
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
309+
config.floatX
310+
)
311+
],
298312
)
299313

300314

@@ -405,9 +419,9 @@ def test_jax_eye():
405419

406420

407421
def test_jax_basic_multiout():
422+
rng = np.random.default_rng(213234)
408423

409-
np.random.seed(213234)
410-
M = np.random.normal(size=(3, 3))
424+
M = rng.normal(size=(3, 3))
411425
X = M.dot(M.T)
412426

413427
x = matrix("x")
@@ -638,7 +652,9 @@ def test_jax_Subtensors_omni():
638652
reason="Omnistaging cannot be disabled",
639653
)
640654
def test_jax_IncSubtensor():
641-
x_np = np.random.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
655+
rng = np.random.default_rng(213234)
656+
657+
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
642658
x_aet = aet.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
643659

644660
# "Set" basic indices
@@ -661,7 +677,7 @@ def test_jax_IncSubtensor():
661677

662678
# "Set" advanced indices
663679
st_aet = aet.as_tensor_variable(
664-
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
680+
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
665681
)
666682
out_aet = aet_subtensor.set_subtensor(x_aet[np.r_[0, 2]], st_aet)
667683
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
@@ -707,7 +723,7 @@ def test_jax_IncSubtensor():
707723

708724
# "Increment" advanced indices
709725
st_aet = aet.as_tensor_variable(
710-
np.random.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
726+
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
711727
)
712728
out_aet = aet_subtensor.inc_subtensor(x_aet[np.r_[0, 2]], st_aet)
713729
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor1)
@@ -1202,6 +1218,7 @@ def rng_fn(cls, rng, size):
12021218
compare_jax_and_py(fgraph, [])
12031219

12041220

1221+
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
12051222
def test_RandomStream():
12061223
srng = RandomStream(seed=123)
12071224
out = srng.normal() - srng.normal()
@@ -1211,3 +1228,11 @@ def test_RandomStream():
12111228
jax_res_2 = fn()
12121229

12131230
assert np.array_equal(jax_res_1, jax_res_2)
1231+
1232+
1233+
@pytest.mark.xfail(reason="Generators not yet supported in JAX")
1234+
def test_random_generators():
1235+
rng = shared(np.random.default_rng(123))
1236+
out = normal(rng=rng)
1237+
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
1238+
compare_jax_and_py(fgraph, [])

0 commit comments

Comments
 (0)