
Commit 92f53e2

Solving some test issues
1 parent 7df44de commit 92f53e2

File tree: 2 files changed (+53, -43 lines)


tests/link/numba/test_elemwise.py

Lines changed: 5 additions & 4 deletions
@@ -146,7 +146,7 @@ def test_elemwise_speed(benchmark):
 
 
 @pytest.mark.parametrize(
-    "v, new_order",
+    "test_values, new_order",
     [
         # `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
         (
@@ -204,14 +204,15 @@ def test_elemwise_speed(benchmark):
         ),
     ],
 )
-def test_Dimshuffle(v, new_order):
+def test_Dimshuffle(test_values, new_order):
+    v = next(iter(test_values.keys()))
     g = v.dimshuffle(new_order)
     g_fg = FunctionGraph(outputs=[g])
     compare_numba_and_py(
         g_fg,
         [
-            i.tag.test_value
-            for i in g_fg.inputs
+            test_values[i]
+            for i in test_values
             if not isinstance(i, SharedVariable | Constant)
         ],
     )
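For context on the numba change: test_Dimshuffle now receives a test_values mapping from each symbolic input to its concrete array, instead of reading i.tag.test_value off the graph inputs. A minimal sketch of what one parametrize case could look like under the new signature; the variable name, shape, and new_order below are illustrative assumptions, not values taken from this commit:

import numpy as np
import pytensor.tensor as pt

# Hypothetical case for the new (test_values, new_order) parametrization:
# the dict supplies the concrete input value, so tag.test_value is not needed.
x = pt.tensor3("x")  # symbolic 3-D input (illustrative name and rank)
case = (
    {x: np.zeros((2, 3, 4), dtype=x.dtype)},  # test_values: variable -> array
    ("x", 2, 0, 1),  # new_order: prepend a broadcastable dim, then permute axes
)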

tests/scan/test_basic.py

Lines changed: 48 additions & 39 deletions
@@ -2241,38 +2241,42 @@ def test_compute_test_value_grad():
     """
     See https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
     """
-    # WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
+    WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
 
-    with config.change_flags(exception_verbosity="high"):
-        W_flat = fvector(name="W")
-        W = W_flat.reshape((2, 2, 3))
+    with config.change_flags(compute_test_value="raise", exception_verbosity="high"):
+        with pytest.warns(FutureWarning):
+            W_flat = fvector(name="W")
+            W_flat.tag.test_value = WEIGHT
+        W = W_flat.reshape((2, 2, 3))
 
-        outputs_mi = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+        outputs_mi = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+        outputs_mi.tag.test_value = np.asarray(0, dtype="float32")
 
-        def loss_mi(mi, sum_mi, W):
-            outputs_ti = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+        def loss_mi(mi, sum_mi, W):
+            outputs_ti = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+            outputs_ti.tag.test_value = np.asarray(0, dtype="float32")
 
-            def loss_ti(ti, sum_ti, mi, W):
-                return W.sum().sum().sum() + sum_ti
+            def loss_ti(ti, sum_ti, mi, W):
+                return W.sum().sum().sum() + sum_ti
 
-            result_ti, _ = scan(
-                fn=loss_ti,
-                outputs_info=outputs_ti,
-                sequences=pt.arange(W.shape[1], dtype="int32"),
-                non_sequences=[mi, W],
-            )
-            lossmi = result_ti[-1]
-            return sum_mi + lossmi
+            result_ti, _ = scan(
+                fn=loss_ti,
+                outputs_info=outputs_ti,
+                sequences=pt.arange(W.shape[1], dtype="int32"),
+                non_sequences=[mi, W],
+            )
+            lossmi = result_ti[-1]
+            return sum_mi + lossmi
 
-        result_mi, _ = scan(
-            fn=loss_mi,
-            outputs_info=outputs_mi,
-            sequences=pt.arange(W.shape[0], dtype="int32"),
-            non_sequences=[W],
-        )
+        result_mi, _ = scan(
+            fn=loss_mi,
+            outputs_info=outputs_mi,
+            sequences=pt.arange(W.shape[0], dtype="int32"),
+            non_sequences=[W],
+        )
 
-        loss = result_mi[-1]
-        grad(loss, W_flat)
+        loss = result_mi[-1]
+        grad(loss, W_flat)
 
 
 @pytest.mark.xfail(reason="NominalVariables don't support test values")
@@ -2281,23 +2285,28 @@ def test_compute_test_value_grad_cast():
 
     See https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
     """
-    h = matrix("h")
-    with pytest.warns(FutureWarning):
-        h.tag.test_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX)
+    with config.change_flags(compute_test_value="raise"):
+        with pytest.warns(FutureWarning):
+            h = matrix("h")
+            h.tag.test_value = np.array(
+                [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX
+            )
 
-    w = shared(
-        np.random.default_rng(utt.fetch_seed()).random((4, 3)).astype(config.floatX),
-        name="w",
-    )
+        w = shared(
+            np.random.default_rng(utt.fetch_seed())
+            .random((4, 3))
+            .astype(config.floatX),
+            name="w",
+        )
 
-    outputs, _ = scan(
-        lambda i, h, w: (dot(h[i], w), i),
-        outputs_info=[None, 0],
-        non_sequences=[h, w],
-        n_steps=3,
-    )
+        outputs, _ = scan(
+            lambda i, h, w: (dot(h[i], w), i),
+            outputs_info=[None, 0],
+            non_sequences=[h, w],
+            n_steps=3,
+        )
 
-    grad(outputs[0].sum(), w)
+        grad(outputs[0].sum(), w)
 
 
 def test_constant_folding_n_steps():
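Both scan-test updates follow the same pattern: the test body is wrapped in config.change_flags(compute_test_value="raise") so the test values are actually used, and each tag.test_value assignment sits inside pytest.warns(FutureWarning), matching the deprecation warning the diff expects. A stripped-down sketch of that pattern, assuming a standard PyTensor plus pytest setup; the variable and values are illustrative, not from the commit:

import numpy as np
import pytest
from pytensor import config
from pytensor.tensor import fvector


def demo_test_value_pattern():
    # Same structure as the updated tests: enable test-value computation for
    # the block, and expect the FutureWarning when a test value is assigned.
    with config.change_flags(compute_test_value="raise"):
        with pytest.warns(FutureWarning):
            v = fvector("v")
            v.tag.test_value = np.zeros(3, dtype="float32")
        # ... the real tests go on to build a scan graph and call grad(...)
        return v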

0 commit comments
