@@ -295,7 +295,7 @@ def test_basic(self):
295
295
296
296
297
297
class TestAssertShape :
298
- @change_flags ([( " conv__assert_shape" , True )] )
298
+ @change_flags (conv__assert_shape = True )
299
299
def test_basic (self ):
300
300
x = tensor .tensor4 ()
301
301
s1 = tensor .iscalar ()
@@ -318,7 +318,7 @@ def test_basic(self):
318
318
with pytest .raises (AssertionError ):
319
319
f (v , 7 , 7 )
320
320
321
- @change_flags ([( " conv__assert_shape" , True )] )
321
+ @change_flags (conv__assert_shape = True )
322
322
def test_shape_check_conv2d (self ):
323
323
input = tensor .tensor4 ()
324
324
filters = tensor .tensor4 ()
@@ -340,7 +340,7 @@ def test_shape_check_conv2d(self):
340
340
np .zeros ((7 , 5 , 2 , 2 ), dtype = "float32" ),
341
341
)
342
342
343
- @change_flags ([( " conv__assert_shape" , True )] )
343
+ @change_flags (conv__assert_shape = True )
344
344
@pytest .mark .skipif (theano .config .cxx == "" , reason = "test needs cxx" )
345
345
def test_shape_check_conv3d (self ):
346
346
input = tensor .tensor5 ()
@@ -363,7 +363,7 @@ def test_shape_check_conv3d(self):
363
363
np .zeros ((7 , 5 , 2 , 2 , 2 ), dtype = "float32" ),
364
364
)
365
365
366
- @change_flags ([( " conv__assert_shape" , True )] )
366
+ @change_flags (conv__assert_shape = True )
367
367
def test_shape_check_conv2d_grad_wrt_inputs (self ):
368
368
output_grad = tensor .tensor4 ()
369
369
filters = tensor .tensor4 ()
@@ -382,7 +382,7 @@ def test_shape_check_conv2d_grad_wrt_inputs(self):
382
382
np .zeros ((7 , 6 , 3 , 3 ), dtype = "float32" ),
383
383
)
384
384
385
- @change_flags ([( " conv__assert_shape" , True )] )
385
+ @change_flags (conv__assert_shape = True )
386
386
@pytest .mark .skipif (theano .config .cxx == "" , reason = "test needs cxx" )
387
387
def test_shape_check_conv3d_grad_wrt_inputs (self ):
388
388
output_grad = tensor .tensor5 ()
@@ -402,7 +402,7 @@ def test_shape_check_conv3d_grad_wrt_inputs(self):
402
402
np .zeros ((7 , 6 , 3 , 3 , 3 ), dtype = "float32" ),
403
403
)
404
404
405
- @change_flags ([( " conv__assert_shape" , True )] )
405
+ @change_flags (conv__assert_shape = True )
406
406
def test_shape_check_conv2d_grad_wrt_weights (self ):
407
407
input = tensor .tensor4 ()
408
408
output_grad = tensor .tensor4 ()
@@ -421,7 +421,7 @@ def test_shape_check_conv2d_grad_wrt_weights(self):
421
421
np .zeros ((3 , 7 , 5 , 9 ), dtype = "float32" ),
422
422
)
423
423
424
- @change_flags ([( " conv__assert_shape" , True )] )
424
+ @change_flags (conv__assert_shape = True )
425
425
@pytest .mark .skipif (theano .config .cxx == "" , reason = "test needs cxx" )
426
426
def test_shape_check_conv3d_grad_wrt_weights (self ):
427
427
input = tensor .tensor5 ()
0 commit comments