Skip to content

Commit 21daf4e

Browse files
Validate test values using a tensor's type
1 parent 74ee82b commit 21daf4e

File tree

4 files changed

+30
-8
lines changed

4 files changed

+30
-8
lines changed

tests/gof/test_compute_test_value.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,16 @@ def test_constant(self):
167167

168168
@theano.change_flags(compute_test_value="raise")
169169
def test_incorrect_type(self):
170-
x = tt.fmatrix("x")
171-
# Incorrect dtype (float64) for test_value
172-
x.tag.test_value = np.random.rand(3, 4)
173-
y = tt.dmatrix("y")
174-
y.tag.test_value = np.random.rand(4, 5)
175170

171+
x = tt.vector("x")
176172
with pytest.raises(TypeError):
177-
tt.dot(x, y)
173+
# Incorrect shape for test value
174+
x.tag.test_value = np.empty((2, 2))
175+
176+
x = tt.fmatrix("x")
177+
with pytest.raises(TypeError):
178+
# Incorrect dtype (float64) for test value
179+
x.tag.test_value = np.random.rand(3, 4)
178180

179181
@theano.change_flags(compute_test_value="raise")
180182
def test_overided_function(self):

theano/gof/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ class Variable(Node):
383383
def __init__(self, type, owner=None, index=None, name=None):
384384
super(Variable, self).__init__()
385385

386-
self.tag = utils.Scratchpad()
386+
self.tag = utils.ValidatingScratchpad("test_value", type.filter)
387387

388388
self.type = type
389389
if owner is not None and not isinstance(owner, Apply):

theano/gof/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,23 @@ def info(self):
259259
print(" %s: %s" % (k, v))
260260

261261

262+
class ValidatingScratchpad(Scratchpad):
263+
"""This `Scratchpad` validates attribute values."""
264+
265+
def __init__(self, attr, attr_filter):
266+
super().__init__()
267+
268+
object.__setattr__(self, "attr", attr)
269+
object.__setattr__(self, "attr_filter", attr_filter)
270+
271+
def __setattr__(self, attr, obj):
272+
273+
if getattr(self, "attr", None) == attr:
274+
obj = self.attr_filter(obj)
275+
276+
return object.__setattr__(self, attr, obj)
277+
278+
262279
class D:
263280
def __init__(self, **d):
264281
self.__dict__.update(d)

theano/tensor/opt.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7743,7 +7743,10 @@ def local_fuse(node):
77437743
if tv.size > 0:
77447744
tmp.tag.test_value = tv.flatten()[0]
77457745
else:
7746-
tmp.tag.test_value = tv
7746+
_logger.warning(
7747+
"Cannot construct a scalar test value"
7748+
" from a test value with no size: {}".format(ii)
7749+
)
77477750
except AttributeError:
77487751
pass
77497752

0 commit comments

Comments
 (0)