@@ -2241,38 +2241,42 @@ def test_compute_test_value_grad():
     """
     See https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
     """
-    # WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
+    WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")

-    with config.change_flags(exception_verbosity="high"):
-        W_flat = fvector(name="W")
-        W = W_flat.reshape((2, 2, 3))
+    with config.change_flags(compute_test_value="raise", exception_verbosity="high"):
+        with pytest.warns(FutureWarning):
+            W_flat = fvector(name="W")
+            W_flat.tag.test_value = WEIGHT
+            W = W_flat.reshape((2, 2, 3))

-        outputs_mi = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+            outputs_mi = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+            outputs_mi.tag.test_value = np.asarray(0, dtype="float32")

-        def loss_mi(mi, sum_mi, W):
-            outputs_ti = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+            def loss_mi(mi, sum_mi, W):
+                outputs_ti = pt.as_tensor_variable(np.asarray(0, dtype="float32"))
+                outputs_ti.tag.test_value = np.asarray(0, dtype="float32")

-            def loss_ti(ti, sum_ti, mi, W):
-                return W.sum().sum().sum() + sum_ti
+                def loss_ti(ti, sum_ti, mi, W):
+                    return W.sum().sum().sum() + sum_ti

-            result_ti, _ = scan(
-                fn=loss_ti,
-                outputs_info=outputs_ti,
-                sequences=pt.arange(W.shape[1], dtype="int32"),
-                non_sequences=[mi, W],
-            )
-            lossmi = result_ti[-1]
-            return sum_mi + lossmi
+                result_ti, _ = scan(
+                    fn=loss_ti,
+                    outputs_info=outputs_ti,
+                    sequences=pt.arange(W.shape[1], dtype="int32"),
+                    non_sequences=[mi, W],
+                )
+                lossmi = result_ti[-1]
+                return sum_mi + lossmi

-        result_mi, _ = scan(
-            fn=loss_mi,
-            outputs_info=outputs_mi,
-            sequences=pt.arange(W.shape[0], dtype="int32"),
-            non_sequences=[W],
-        )
+            result_mi, _ = scan(
+                fn=loss_mi,
+                outputs_info=outputs_mi,
+                sequences=pt.arange(W.shape[0], dtype="int32"),
+                non_sequences=[W],
+            )

-        loss = result_mi[-1]
-        grad(loss, W_flat)
+            loss = result_mi[-1]
+            grad(loss, W_flat)


 @pytest.mark.xfail(reason="NominalVariables don't support test values")
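Note on the hunk above: the pattern being restored is PyTensor's deprecated test-value machinery. Below is a minimal, self-contained sketch of the same idiom outside scan; the names x and loss are illustrative rather than taken from the test, and it assumes a PyTensor version where tag.test_value still works but emits a FutureWarning.

    import numpy as np
    import pytest
    import pytensor.tensor as pt
    from pytensor import config, grad

    # With compute_test_value="raise", every new node is eagerly evaluated on
    # the test values, so shape/dtype errors surface at graph-construction time.
    with config.change_flags(compute_test_value="raise"):
        with pytest.warns(FutureWarning):
            x = pt.fvector("x")
            x.tag.test_value = np.arange(3, dtype="float32")  # deprecated setter warns
            loss = (x**2).sum()  # test value is computed here, eagerly
            grad(loss, x)  # the gradient graph gets test values as well

This is also why the commit nests the whole test body under pytest.warns(FutureWarning): every graph-building line can touch the deprecated tag.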
@@ -2281,23 +2285,28 @@ def test_compute_test_value_grad_cast():

     See https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
     """
-    h = matrix("h")
-    with pytest.warns(FutureWarning):
-        h.tag.test_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX)
+    with config.change_flags(compute_test_value="raise"):
+        with pytest.warns(FutureWarning):
+            h = matrix("h")
+            h.tag.test_value = np.array(
+                [[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX
+            )

-    w = shared(
-        np.random.default_rng(utt.fetch_seed()).random((4, 3)).astype(config.floatX),
-        name="w",
-    )
+            w = shared(
+                np.random.default_rng(utt.fetch_seed())
+                .random((4, 3))
+                .astype(config.floatX),
+                name="w",
+            )

-    outputs, _ = scan(
-        lambda i, h, w: (dot(h[i], w), i),
-        outputs_info=[None, 0],
-        non_sequences=[h, w],
-        n_steps=3,
-    )
+            outputs, _ = scan(
+                lambda i, h, w: (dot(h[i], w), i),
+                outputs_info=[None, 0],
+                non_sequences=[h, w],
+                n_steps=3,
+            )

-    grad(outputs[0].sum(), w)
+            grad(outputs[0].sum(), w)


 def test_constant_folding_n_steps():
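Note on the second hunk: stripped of the test-value plumbing, the grad-through-scan idiom it protects looks roughly like the sketch below (same hedges as above; g and f are illustrative names, and pytensor.scan is assumed to be the same scan the test module imports).

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor import config, function, grad, shared

    h = pt.matrix("h")
    w = shared(np.ones((4, 3), dtype=config.floatX), name="w")
    # `i` is a recurring output seeded with the integer constant 0; the test
    # returns it unchanged, so every step reads row h[0] and the integer seed
    # has to be cast against the float outputs.
    outputs, _ = pytensor.scan(
        lambda i, h, w: (pt.dot(h[i], w), i),
        outputs_info=[None, 0],
        non_sequences=[h, w],
        n_steps=3,
    )
    g = grad(outputs[0].sum(), w)  # differentiate through the scan
    f = function([h], g)  # call as f(np.ones((1, 4), dtype=config.floatX))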