Skip to content

Commit 1897c25

Browse files
Solving some test issues
1 parent 7df44de commit 1897c25

File tree

2 files changed

+23
-17
lines changed

2 files changed

+23
-17
lines changed

tests/link/numba/test_elemwise.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_elemwise_speed(benchmark):
146146

147147

148148
@pytest.mark.parametrize(
149-
"v, new_order",
149+
"test_values, new_order",
150150
[
151151
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
152152
(
@@ -204,14 +204,15 @@ def test_elemwise_speed(benchmark):
204204
),
205205
],
206206
)
207-
def test_Dimshuffle(v, new_order):
207+
def test_Dimshuffle(test_values, new_order):
208+
v = next(iter(test_values.keys()))
208209
g = v.dimshuffle(new_order)
209210
g_fg = FunctionGraph(outputs=[g])
210211
compare_numba_and_py(
211212
g_fg,
212213
[
213-
i.tag.test_value
214-
for i in g_fg.inputs
214+
test_values[i]
215+
for i in test_values
215216
if not isinstance(i, SharedVariable | Constant)
216217
],
217218
)

tests/scan/test_basic.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,16 +2241,19 @@ def test_compute_test_value_grad():
22412241
"""
22422242
See https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
22432243
"""
2244-
# WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
2244+
WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
22452245

22462246
with config.change_flags(exception_verbosity="high"):
22472247
W_flat = fvector(name="W")
2248+
W_flat.tag.test_value = WEIGHT
22482249
W = W_flat.reshape((2, 2, 3))
22492250

22502251
outputs_mi = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
2252+
outputs_mi.tag.test_value = np.asarray(0, dtype="float32")
22512253

22522254
def loss_mi(mi, sum_mi, W):
22532255
outputs_ti = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
2256+
outputs_ti.tag.test_value = np.asarray(0, dtype="float32")
22542257

22552258
def loss_ti(ti, sum_ti, mi, W):
22562259
return W.sum().sum().sum() + sum_ti
@@ -2281,23 +2284,25 @@ def test_compute_test_value_grad_cast():
22812284
22822285
See https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
22832286
"""
2284-
h = matrix("h")
22852287
with pytest.warns(FutureWarning):
2288+
h = matrix("h")
22862289
h.tag.test_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX)
22872290

2288-
w = shared(
2289-
np.random.default_rng(utt.fetch_seed()).random((4, 3)).astype(config.floatX),
2290-
name="w",
2291-
)
2291+
w = shared(
2292+
np.random.default_rng(utt.fetch_seed())
2293+
.random((4, 3))
2294+
.astype(config.floatX),
2295+
name="w",
2296+
)
22922297

2293-
outputs, _ = scan(
2294-
lambda i, h, w: (dot(h[i], w), i),
2295-
outputs_info=[None, 0],
2296-
non_sequences=[h, w],
2297-
n_steps=3,
2298-
)
2298+
outputs, _ = scan(
2299+
lambda i, h, w: (dot(h[i], w), i),
2300+
outputs_info=[None, 0],
2301+
non_sequences=[h, w],
2302+
n_steps=3,
2303+
)
22992304

2300-
grad(outputs[0].sum(), w)
2305+
grad(outputs[0].sum(), w)
23012306

23022307

23032308
def test_constant_folding_n_steps():

0 commit comments

Comments
 (0)