Skip to content

Commit 9f7a1b6

Browse files
Merge pull request #110 from brandonwillard/enforce-sane-test-values
Refactor test value framework so that test value validation is performed up-front.
2 parents 53486aa + ea44b16 commit 9f7a1b6

22 files changed

+192
-464
lines changed

doc/library/misc/pkl_utils.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515

1616
.. autoclass:: theano.misc.pkl_utils.StripPickler
1717

18-
.. autoclass:: theano.misc.pkl_utils.CompatUnpickler
19-
2018
.. seealso::
2119

2220
:ref:`tutorial_loadsave`
23-
24-

tests/gof/test_compute_test_value.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,16 @@ def test_constant(self):
167167

168168
@theano.change_flags(compute_test_value="raise")
169169
def test_incorrect_type(self):
170-
x = tt.fmatrix("x")
171-
# Incorrect dtype (float64) for test_value
172-
x.tag.test_value = np.random.rand(3, 4)
173-
y = tt.dmatrix("y")
174-
y.tag.test_value = np.random.rand(4, 5)
175170

171+
x = tt.vector("x")
176172
with pytest.raises(TypeError):
177-
tt.dot(x, y)
173+
# Incorrect shape for test value
174+
x.tag.test_value = np.empty((2, 2))
175+
176+
x = tt.fmatrix("x")
177+
with pytest.raises(TypeError):
178+
# Incorrect dtype (float64) for test value
179+
x.tag.test_value = np.random.rand(3, 4)
178180

179181
@theano.change_flags(compute_test_value="raise")
180182
def test_overided_function(self):

tests/gof/test_fg.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
1-
import os
21
import pickle
32

4-
import pytest
5-
6-
import theano
7-
from theano.compat import PY3
8-
from theano.gof.fg import FunctionGraph
93
from theano import tensor as tt
4+
from theano.gof.fg import FunctionGraph
105

116

127
class TestFunctionGraph:
@@ -16,24 +11,3 @@ def test_pickle(self):
1611

1712
s = pickle.dumps(func)
1813
pickle.loads(s)
19-
20-
@pytest.mark.skipif(
21-
not theano.config.cxx, reason="G++ not available, so we need to skip this test."
22-
)
23-
@pytest.mark.slow
24-
def test_node_outputs_not_used(self):
25-
# In the past, we where removing some not used variable from
26-
# fgraph.variables event if the apply had other output used in
27-
# the graph. This caused a crash.
28-
# This test run the pickle that reproduce this case.
29-
with open(
30-
os.path.join(os.path.dirname(__file__), "test_fg_old_crash.pkl"), "rb"
31-
) as f:
32-
from theano.misc.pkl_utils import CompatUnpickler
33-
34-
if PY3:
35-
u = CompatUnpickler(f, encoding="latin1")
36-
else:
37-
u = CompatUnpickler(f)
38-
d = u.load()
39-
f = theano.function(**d)

tests/gof/test_fg_old_crash.pkl

-168 KB
Binary file not shown.

tests/gof/test_op.py

Lines changed: 32 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import numpy as np
22
import pytest
3-
43
import theano
54
import theano.gof.op as op
5+
import theano.tensor as tt
6+
67
from six import string_types
7-
from theano.gof.type import Type, Generic
8+
from theano import scalar, shared
9+
from theano.configparser import change_flags
810
from theano.gof.graph import Apply, Variable
9-
import theano.tensor as T
10-
from theano import scalar
11-
from theano import shared
11+
from theano.gof.type import Generic, Type
1212

1313
config = theano.config
1414
Op = op.Op
@@ -238,15 +238,15 @@ class DoubleOp(Op):
238238

239239
__props__ = ()
240240

241-
itypes = [T.dmatrix]
242-
otypes = [T.dmatrix]
241+
itypes = [tt.dmatrix]
242+
otypes = [tt.dmatrix]
243243

244244
def perform(self, node, inputs, outputs):
245245
inp = inputs[0]
246246
output = outputs[0]
247247
output[0] = inp * 2
248248

249-
x_input = T.dmatrix("x_input")
249+
x_input = tt.dmatrix("x_input")
250250
f = theano.function([x_input], DoubleOp()(x_input))
251251
inp = np.random.rand(5, 4)
252252
out = f(inp)
@@ -255,17 +255,17 @@ def perform(self, node, inputs, outputs):
255255

256256
def test_test_value_python_objects():
257257
for x in ([0, 1, 2], 0, 0.5, 1):
258-
assert (op.get_test_value(x) == x).all()
258+
assert np.all(op.get_test_value(x) == x)
259259

260260

261261
def test_test_value_ndarray():
262262
x = np.zeros((5, 5))
263263
v = op.get_test_value(x)
264-
assert (v == x).all()
264+
assert np.all(v == x)
265265

266266

267267
def test_test_value_constant():
268-
x = T.as_tensor_variable(np.zeros((5, 5)))
268+
x = tt.as_tensor_variable(np.zeros((5, 5)))
269269
v = op.get_test_value(x)
270270

271271
assert np.all(v == np.zeros((5, 5)))
@@ -278,62 +278,37 @@ def test_test_value_shared():
278278
assert np.all(v == np.zeros((5, 5)))
279279

280280

281+
@change_flags(compute_test_value="raise")
281282
def test_test_value_op():
282-
try:
283-
prev_value = config.compute_test_value
284-
config.compute_test_value = "raise"
285-
x = T.log(np.ones((5, 5)))
286-
v = op.get_test_value(x)
287-
288-
assert np.allclose(v, np.zeros((5, 5)))
289-
finally:
290-
config.compute_test_value = prev_value
291-
292283

293-
def test_get_debug_values_no_debugger():
294-
"get_debug_values should return [] when debugger is off"
284+
x = tt.log(np.ones((5, 5)))
285+
v = op.get_test_value(x)
295286

296-
prev_value = config.compute_test_value
297-
try:
298-
config.compute_test_value = "off"
287+
assert np.allclose(v, np.zeros((5, 5)))
299288

300-
x = T.vector()
301289

302-
for x_val in op.get_debug_values(x):
303-
assert False
290+
@change_flags(compute_test_value="off")
291+
def test_get_debug_values_no_debugger():
292+
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
304293

305-
finally:
306-
config.compute_test_value = prev_value
294+
x = tt.vector()
295+
assert op.get_debug_values(x) == []
307296

308297

298+
@change_flags(compute_test_value="ignore")
309299
def test_get_det_debug_values_ignore():
310-
# get_debug_values should return [] when debugger is ignore
311-
# and some values are missing
300+
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
312301

313-
prev_value = config.compute_test_value
314-
try:
315-
config.compute_test_value = "ignore"
316-
317-
x = T.vector()
318-
319-
for x_val in op.get_debug_values(x):
320-
assert False
321-
322-
finally:
323-
config.compute_test_value = prev_value
302+
x = tt.vector()
303+
assert op.get_debug_values(x) == []
324304

325305

326306
def test_get_debug_values_success():
327-
# tests that get_debug_value returns values when available
328-
# (and the debugger is on)
307+
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
329308

330-
prev_value = config.compute_test_value
331309
for mode in ["ignore", "warn", "raise"]:
332-
333-
try:
334-
config.compute_test_value = mode
335-
336-
x = T.vector()
310+
with change_flags(compute_test_value=mode):
311+
x = tt.vector()
337312
x.tag.test_value = np.zeros((4,), dtype=config.floatX)
338313
y = np.zeros((5, 5))
339314

@@ -348,54 +323,11 @@ def test_get_debug_values_success():
348323

349324
assert iters == 1
350325

351-
finally:
352-
config.compute_test_value = prev_value
353-
354326

327+
@change_flags(compute_test_value="raise")
355328
def test_get_debug_values_exc():
356-
# tests that get_debug_value raises an exception when
357-
# debugger is set to raise and a value is missing
358-
359-
prev_value = config.compute_test_value
360-
try:
361-
config.compute_test_value = "raise"
362-
363-
x = T.vector()
364-
365-
try:
366-
for x_val in op.get_debug_values(x):
367-
# this assert catches the case where we
368-
# erroneously get a value returned
369-
assert False
370-
raised = False
371-
except AttributeError:
372-
raised = True
373-
374-
# this assert catches the case where we got []
375-
# returned, and possibly issued a warning,
376-
# rather than raising an exception
377-
assert raised
329+
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
378330

379-
finally:
380-
config.compute_test_value = prev_value
381-
382-
383-
def test_debug_error_message():
384-
# tests that debug_error_message raises an
385-
# exception when it should.
386-
387-
prev_value = config.compute_test_value
388-
389-
for mode in ["ignore", "raise"]:
390-
391-
try:
392-
config.compute_test_value = mode
393-
394-
try:
395-
op.debug_error_message("msg")
396-
raised = False
397-
except ValueError:
398-
raised = True
399-
assert raised
400-
finally:
401-
config.compute_test_value = prev_value
331+
with pytest.raises(AttributeError):
332+
x = tt.vector()
333+
assert op.get_debug_values(x) == []

tests/gpuarray/test_multinomial.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
import tests.unittest_tools as utt
99

10+
from pickle import Unpickler
11+
1012
from theano import config, function, tensor
1113
from theano.compat import PY3
12-
from theano.misc.pkl_utils import CompatUnpickler
1314
from theano.sandbox import multinomial
1415
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
1516
from theano.gpuarray.multinomial import (
@@ -384,6 +385,6 @@ def test_unpickle_legacy_op():
384385

385386
if not PY3:
386387
with open(os.path.join(testfile_dir, fname), "r") as fp:
387-
u = CompatUnpickler(fp)
388+
u = Unpickler(fp)
388389
m = u.load()
389390
assert isinstance(m, GPUAChoiceFromUniform)

tests/gpuarray/test_pickle.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
import numpy as np
1515

16+
from pickle import Unpickler
17+
1618
from theano import config
17-
from theano.compat import PY3
18-
from theano.misc.pkl_utils import CompatUnpickler
1919

2020
from theano.gpuarray.type import ContextNotDefined
2121

@@ -37,10 +37,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag1():
3737
fname = "GpuArray.pkl"
3838

3939
with open(os.path.join(testfile_dir, fname), "rb") as fp:
40-
if PY3:
41-
u = CompatUnpickler(fp, encoding="latin1")
42-
else:
43-
u = CompatUnpickler(fp)
40+
u = Unpickler(fp, encoding="latin1")
4441
with pytest.raises((ImportError, ContextNotDefined)):
4542
u.load()
4643
finally:
@@ -56,10 +53,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag2():
5653
fname = "GpuArray.pkl"
5754

5855
with open(os.path.join(testfile_dir, fname), "rb") as fp:
59-
if PY3:
60-
u = CompatUnpickler(fp, encoding="latin1")
61-
else:
62-
u = CompatUnpickler(fp)
56+
u = Unpickler(fp, encoding="latin1")
6357
try:
6458
mat = u.load()
6559
except ImportError:

tests/gpuarray/test_type.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
pygpu = pytest.importorskip("pygpu")
77

8-
from theano.compat import PY3
8+
from pickle import Unpickler
9+
910
from theano import config
1011
from theano.compile import DeepCopyOp, Rebroadcast, ViewOp
11-
from theano.misc.pkl_utils import CompatUnpickler
1212
from theano.gpuarray.type import GpuArrayType, gpuarray_shared_constructor
1313

1414
from tests.gpuarray.config import test_ctx_name
@@ -122,10 +122,7 @@ def test_unpickle_gpuarray_as_numpy_ndarray_flag0():
122122
fname = "GpuArray.pkl"
123123

124124
with open(os.path.join(testfile_dir, fname), "rb") as fp:
125-
if PY3:
126-
u = CompatUnpickler(fp, encoding="latin1")
127-
else:
128-
u = CompatUnpickler(fp)
125+
u = Unpickler(fp, encoding="latin1")
129126
mat = u.load()
130127
assert isinstance(mat, pygpu.gpuarray.GpuArray)
131128
assert np.asarray(mat)[0] == -42.0

0 commit comments

Comments
 (0)