Skip to content

Commit 069b58a

Browse files
michaelosthege authored and brandonwillard committed
Pass kwargs instead of list of tuples to change_flags
1 parent f7898c4 commit 069b58a

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

tests/tensor/nnet/test_abstract_conv.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def test_basic(self):
295295

296296

297297
class TestAssertShape:
298-
@change_flags([("conv__assert_shape", True)])
298+
@change_flags(conv__assert_shape=True)
299299
def test_basic(self):
300300
x = tensor.tensor4()
301301
s1 = tensor.iscalar()
@@ -318,7 +318,7 @@ def test_basic(self):
318318
with pytest.raises(AssertionError):
319319
f(v, 7, 7)
320320

321-
@change_flags([("conv__assert_shape", True)])
321+
@change_flags(conv__assert_shape=True)
322322
def test_shape_check_conv2d(self):
323323
input = tensor.tensor4()
324324
filters = tensor.tensor4()
@@ -340,7 +340,7 @@ def test_shape_check_conv2d(self):
340340
np.zeros((7, 5, 2, 2), dtype="float32"),
341341
)
342342

343-
@change_flags([("conv__assert_shape", True)])
343+
@change_flags(conv__assert_shape=True)
344344
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
345345
def test_shape_check_conv3d(self):
346346
input = tensor.tensor5()
@@ -363,7 +363,7 @@ def test_shape_check_conv3d(self):
363363
np.zeros((7, 5, 2, 2, 2), dtype="float32"),
364364
)
365365

366-
@change_flags([("conv__assert_shape", True)])
366+
@change_flags(conv__assert_shape=True)
367367
def test_shape_check_conv2d_grad_wrt_inputs(self):
368368
output_grad = tensor.tensor4()
369369
filters = tensor.tensor4()
@@ -382,7 +382,7 @@ def test_shape_check_conv2d_grad_wrt_inputs(self):
382382
np.zeros((7, 6, 3, 3), dtype="float32"),
383383
)
384384

385-
@change_flags([("conv__assert_shape", True)])
385+
@change_flags(conv__assert_shape=True)
386386
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
387387
def test_shape_check_conv3d_grad_wrt_inputs(self):
388388
output_grad = tensor.tensor5()
@@ -402,7 +402,7 @@ def test_shape_check_conv3d_grad_wrt_inputs(self):
402402
np.zeros((7, 6, 3, 3, 3), dtype="float32"),
403403
)
404404

405-
@change_flags([("conv__assert_shape", True)])
405+
@change_flags(conv__assert_shape=True)
406406
def test_shape_check_conv2d_grad_wrt_weights(self):
407407
input = tensor.tensor4()
408408
output_grad = tensor.tensor4()
@@ -421,7 +421,7 @@ def test_shape_check_conv2d_grad_wrt_weights(self):
421421
np.zeros((3, 7, 5, 9), dtype="float32"),
422422
)
423423

424-
@change_flags([("conv__assert_shape", True)])
424+
@change_flags(conv__assert_shape=True)
425425
@pytest.mark.skipif(theano.config.cxx == "", reason="test needs cxx")
426426
def test_shape_check_conv3d_grad_wrt_weights(self):
427427
input = tensor.tensor5()

tests/tensor/nnet/test_nnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ def test_optimize_xent_vector2(self):
874874
assert crossentropy_softmax_argmax_1hot_with_bias in ops
875875
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
876876

877-
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]):
877+
with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
878878
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
879879
optdb.query(OPT_FAST_RUN).optimize(fgraph)
880880

@@ -911,7 +911,7 @@ def test_optimize_xent_vector3(self):
911911
assert crossentropy_softmax_argmax_1hot_with_bias in ops
912912
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
913913

914-
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]):
914+
with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
915915
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
916916
optdb.query(OPT_FAST_RUN).optimize(fgraph)
917917

@@ -949,7 +949,7 @@ def test_optimize_xent_vector4(self):
949949
assert crossentropy_softmax_argmax_1hot_with_bias in ops
950950
assert not [1 for o in ops if isinstance(o, tt.AdvancedSubtensor)]
951951

952-
with theano.change_flags([("warn__sum_div_dimshuffle_bug", False)]):
952+
with theano.change_flags(warn__sum_div_dimshuffle_bug=False):
953953
fgraph = gof.FunctionGraph([x, b, y], [tt.grad(expr, x)])
954954
optdb.query(OPT_FAST_RUN).optimize(fgraph)
955955

0 commit comments

Comments (0)