Skip to content

Commit 6f9352c

Browse files
Added pytest Future Warning in relavant tests
1 parent f3abb76 commit 6f9352c

36 files changed

+656
-499
lines changed

pytensor/graph/basic.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,12 @@ def __init__(
451451

452452
self.tag = ValidatingScratchpad("test_value", type.filter)
453453

454+
if hasattr(self.tag, "test_value"):
455+
warnings.warn(
456+
"test_value machinery is deprecated and will stop working in the future.",
457+
FutureWarning,
458+
)
459+
454460
self.type = type
455461

456462
self._owner = owner
@@ -479,10 +485,7 @@ def get_test_value(self):
479485
if not hasattr(self.tag, "test_value"):
480486
detailed_err_msg = get_variable_trace_string(self)
481487
raise TestValueError(f"{self} has no test value {detailed_err_msg}")
482-
warnings.warn(
483-
"test_value machinery is deprecated and will stop working in the future.",
484-
FutureWarning,
485-
)
488+
486489
return self.tag.test_value
487490

488491
def __str__(self):

pytensor/graph/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import linecache
22
import sys
33
import traceback
4+
import warnings
45
from abc import ABCMeta
56
from collections.abc import Sequence
67
from io import StringIO
@@ -283,9 +284,19 @@ def info(self):
283284

284285
# These two methods have been added to help Mypy
285286
def __getattribute__(self, name):
287+
if name == "test_value":
288+
warnings.warn(
289+
"test_value machinery is deprecated and will stop working in the future.",
290+
FutureWarning,
291+
)
286292
return super().__getattribute__(name)
287293

288294
def __setattr__(self, name: str, value: Any) -> None:
295+
if name == "test_value":
296+
warnings.warn(
297+
"test_value machinery is deprecated and will stop working in the future.",
298+
FutureWarning,
299+
)
289300
self.__dict__[name] = value
290301

291302

@@ -300,6 +311,11 @@ def __init__(self, attr, attr_filter):
300311

301312
def __setattr__(self, attr, obj):
302313
if getattr(self, "attr", None) == attr:
314+
if attr == "test_value":
315+
warnings.warn(
316+
"test_value machinery is deprecated and will stop working in the future.",
317+
FutureWarning,
318+
)
303319
obj = self.attr_filter(obj)
304320

305321
return object.__setattr__(self, attr, obj)

tests/compile/test_builders.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,12 @@ def test_infer_shape(self):
523523

524524
@config.change_flags(compute_test_value="raise")
525525
def test_compute_test_value(self):
526-
x = scalar("x")
527-
x.tag.test_value = np.array(1.0, dtype=config.floatX)
528-
op = OpFromGraph([x], [x**3])
529-
y = scalar("y")
530-
y.tag.test_value = np.array(1.0, dtype=config.floatX)
526+
with pytest.warns(FutureWarning):
527+
x = scalar("x")
528+
x.tag.test_value = np.array(1.0, dtype=config.floatX)
529+
op = OpFromGraph([x], [x**3])
530+
y = scalar("y")
531+
y.tag.test_value = np.array(1.0, dtype=config.floatX)
531532
f = op(y)
532533
grad_f = grad(f, y)
533534
assert grad_f.tag.test_value is not None

tests/compile/test_ops.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pickle
22

33
import numpy as np
4+
import pytest
45

56
from pytensor import function
67
from pytensor.compile.ops import as_op
@@ -32,10 +33,11 @@ def cumprod(x):
3233
assert np.allclose(r, r0), (r, r0)
3334

3435
def test_2arg(self):
35-
x = dmatrix("x")
36-
x.tag.test_value = np.zeros((2, 2))
37-
y = dvector("y")
38-
y.tag.test_value = [0, 0, 0, 0]
36+
with pytest.warns(FutureWarning):
37+
x = dmatrix("x")
38+
x.tag.test_value = np.zeros((2, 2))
39+
y = dvector("y")
40+
y.tag.test_value = [0, 0, 0, 0]
3941

4042
@as_op([dmatrix, dvector], dvector)
4143
def cumprod_plus(x, y):
@@ -48,10 +50,11 @@ def cumprod_plus(x, y):
4850
assert np.allclose(r, r0), (r, r0)
4951

5052
def test_infer_shape(self):
51-
x = dmatrix("x")
52-
x.tag.test_value = np.zeros((2, 2))
53-
y = dvector("y")
54-
y.tag.test_value = [0, 0, 0, 0]
53+
with pytest.warns(FutureWarning):
54+
x = dmatrix("x")
55+
x.tag.test_value = np.zeros((2, 2))
56+
y = dvector("y")
57+
y.tag.test_value = [0, 0, 0, 0]
5558

5659
def infer_shape(fgraph, node, shapes):
5760
x, y = shapes

tests/graph/test_compute_test_value.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def perform(self, node, inputs, outputs):
6767

6868
test_input = SomeType()()
6969
orig_object = object()
70-
test_input.tag.test_value = orig_object
70+
with pytest.warns(FutureWarning):
71+
test_input.tag.test_value = orig_object
7172

7273
res = InplaceOp(False)(test_input)
7374
assert res.tag.test_value is orig_object
@@ -76,10 +77,11 @@ def perform(self, node, inputs, outputs):
7677
assert res.tag.test_value is not orig_object
7778

7879
def test_variable_only(self):
79-
x = matrix("x")
80-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
81-
y = matrix("y")
82-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
80+
with pytest.warns(FutureWarning):
81+
x = matrix("x")
82+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
83+
y = matrix("y")
84+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
8385

8486
# should work
8587
z = dot(x, y)
@@ -88,14 +90,16 @@ def test_variable_only(self):
8890
assert _allclose(f(x.tag.test_value, y.tag.test_value), z.tag.test_value)
8991

9092
# this test should fail
91-
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
93+
with pytest.warns(FutureWarning):
94+
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
9295
with pytest.raises(ValueError):
9396
dot(x, y)
9497

9598
def test_compute_flag(self):
9699
x = matrix("x")
97100
y = matrix("y")
98-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
101+
with pytest.warns(FutureWarning):
102+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
99103

100104
# should skip computation of test value
101105
with config.change_flags(compute_test_value="off"):
@@ -111,10 +115,11 @@ def test_compute_flag(self):
111115
dot(x, y)
112116

113117
def test_string_var(self):
114-
x = matrix("x")
115-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
116-
y = matrix("y")
117-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
118+
with pytest.warns(FutureWarning):
119+
x = matrix("x")
120+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
121+
y = matrix("y")
122+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
118123

119124
z = pytensor.shared(np.random.random((5, 6)).astype(config.floatX))
120125

@@ -134,7 +139,8 @@ def f(x, y, z):
134139

135140
def test_shared(self):
136141
x = matrix("x")
137-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
142+
with pytest.warns(FutureWarning):
143+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
138144
y = pytensor.shared(np.random.random((4, 6)).astype(config.floatX), "y")
139145

140146
# should work
@@ -189,31 +195,34 @@ def test_constant(self):
189195

190196
def test_incorrect_type(self):
191197
x = vector("x")
192-
with pytest.raises(TypeError):
198+
with pytest.raises(TypeError) and pytest.warns(FutureWarning):
193199
# Incorrect shape for test value
194200
x.tag.test_value = np.empty((2, 2))
195201

196202
x = fmatrix("x")
197-
with pytest.raises(TypeError):
203+
with pytest.raises(TypeError) and pytest.warns(FutureWarning):
198204
# Incorrect dtype (float64) for test value
199205
x.tag.test_value = np.random.random((3, 4))
200206

201207
def test_overided_function(self):
202208
# We need to test those as they mess with Exception
203209
# And we don't want the exception to be changed.
204-
x = matrix()
205-
x.tag.test_value = np.zeros((2, 3), dtype=config.floatX)
206-
y = matrix()
207-
y.tag.test_value = np.zeros((2, 2), dtype=config.floatX)
210+
211+
with pytest.warns(FutureWarning):
212+
x = matrix()
213+
x.tag.test_value = np.zeros((2, 3), dtype=config.floatX)
214+
y = matrix()
215+
y.tag.test_value = np.zeros((2, 2), dtype=config.floatX)
208216
with pytest.raises(ValueError):
209217
x.__mul__(y)
210218

211219
def test_scan(self):
212220
# Test the compute_test_value mechanism Scan.
213221
k = iscalar("k")
214222
A = vector("A")
215-
k.tag.test_value = 3
216-
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
223+
with pytest.warns(FutureWarning):
224+
k.tag.test_value = 3
225+
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
217226

218227
def fx(prior_result, A):
219228
return prior_result * A
@@ -233,8 +242,9 @@ def test_scan_err1(self):
233242
# This test should fail when building fx for the first time
234243
k = iscalar("k")
235244
A = matrix("A")
236-
k.tag.test_value = 3
237-
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
245+
with pytest.warns(FutureWarning):
246+
k.tag.test_value = 3
247+
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
238248

239249
def fx(prior_result, A):
240250
return dot(prior_result, A)
@@ -253,8 +263,9 @@ def test_scan_err2(self):
253263
# but when calling the scan's perform()
254264
k = iscalar("k")
255265
A = matrix("A")
256-
k.tag.test_value = 3
257-
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
266+
with pytest.warns(FutureWarning):
267+
k.tag.test_value = 3
268+
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
258269

259270
def fx(prior_result, A):
260271
return dot(prior_result, A)
@@ -288,7 +299,8 @@ def perform(self, node, inputs, outputs):
288299
output[0] = input + 1
289300

290301
i = ps.int32("i")
291-
i.tag.test_value = 3
302+
with pytest.warns(FutureWarning):
303+
i.tag.test_value = 3
292304

293305
o = IncOnePython()(i)
294306

@@ -304,7 +316,8 @@ def perform(self, node, inputs, outputs):
304316
)
305317
def test_no_perform(self):
306318
i = ps.int32("i")
307-
i.tag.test_value = 3
319+
with pytest.warns(FutureWarning):
320+
i.tag.test_value = 3
308321

309322
# Class IncOneC is defined outside of the TestComputeTestValue
310323
# so it can be pickled and unpickled

tests/graph/test_fg.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,19 @@ def test_change_input(self):
232232

233233
@config.change_flags(compute_test_value="raise")
234234
def test_replace_test_value(self):
235-
var1 = MyVariable("var1")
236-
var1.tag.test_value = 1
237-
var2 = MyVariable("var2")
238-
var2.tag.test_value = 2
239-
var3 = op1(var2, var1)
240-
var4 = op2(var3, var2)
241-
var4.tag.test_value = np.array([1, 2])
242-
var5 = op3(var4, var2, var2)
243-
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
244-
245-
var6 = op3()
246-
var6.tag.test_value = np.array(0)
235+
with pytest.warns(FutureWarning):
236+
var1 = MyVariable("var1")
237+
var1.tag.test_value = 1
238+
var2 = MyVariable("var2")
239+
var2.tag.test_value = 2
240+
var3 = op1(var2, var1)
241+
var4 = op2(var3, var2)
242+
var4.tag.test_value = np.array([1, 2])
243+
var5 = op3(var4, var2, var2)
244+
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
245+
246+
var6 = op3()
247+
var6.tag.test_value = np.array(0)
247248

248249
assert var6.tag.test_value.shape != var4.tag.test_value.shape
249250

tests/graph/test_op.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,34 +131,39 @@ def perform(self, node, inputs, outputs):
131131

132132

133133
def test_test_value_python_objects():
134-
for x in ([0, 1, 2], 0, 0.5, 1):
135-
assert np.all(op.get_test_value(x) == x)
134+
with pytest.warns(FutureWarning):
135+
for x in ([0, 1, 2], 0, 0.5, 1):
136+
assert np.all(op.get_test_value(x) == x)
136137

137138

138139
def test_test_value_ndarray():
139140
x = np.zeros((5, 5))
140-
v = op.get_test_value(x)
141+
with pytest.warns(FutureWarning):
142+
v = op.get_test_value(x)
141143
assert np.all(v == x)
142144

143145

144146
def test_test_value_constant():
145147
x = pt.as_tensor_variable(np.zeros((5, 5)))
146-
v = op.get_test_value(x)
148+
with pytest.warns(FutureWarning):
149+
v = op.get_test_value(x)
147150

148151
assert np.all(v == np.zeros((5, 5)))
149152

150153

151154
def test_test_value_shared():
152155
x = shared(np.zeros((5, 5)))
153-
v = op.get_test_value(x)
156+
with pytest.warns(FutureWarning):
157+
v = op.get_test_value(x)
154158

155159
assert np.all(v == np.zeros((5, 5)))
156160

157161

158162
@config.change_flags(compute_test_value="raise")
159163
def test_test_value_op():
160164
x = log(np.ones((5, 5)))
161-
v = op.get_test_value(x)
165+
with pytest.warns(FutureWarning):
166+
v = op.get_test_value(x)
162167

163168
assert np.allclose(v, np.zeros((5, 5)))
164169

@@ -168,22 +173,26 @@ def test_get_test_values_no_debugger():
168173
"""Tests that `get_test_values` returns `[]` when debugger is off."""
169174

170175
x = vector()
171-
assert op.get_test_values(x) == []
176+
with pytest.warns(FutureWarning):
177+
assert op.get_test_values(x) == []
172178

173179

174180
@config.change_flags(compute_test_value="ignore")
175181
def test_get_test_values_ignore():
176182
"""Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
177183

178184
x = vector()
179-
assert op.get_test_values(x) == []
185+
with pytest.warns(FutureWarning):
186+
assert op.get_test_values(x) == []
180187

181188

182189
def test_get_test_values_success():
183190
"""Tests that `get_test_values` returns values when available (and the debugger is on)."""
184191

185192
for mode in ["ignore", "warn", "raise"]:
186-
with config.change_flags(compute_test_value=mode):
193+
with config.change_flags(compute_test_value=mode) and pytest.warns(
194+
FutureWarning
195+
):
187196
x = vector()
188197
x.tag.test_value = np.zeros((4,), dtype=config.floatX)
189198
y = np.zeros((5, 5))
@@ -203,7 +212,7 @@ def test_get_test_values_success():
203212
def test_get_test_values_exc():
204213
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""
205214

206-
with pytest.raises(TestValueError):
215+
with pytest.raises(TestValueError) and pytest.warns(FutureWarning):
207216
x = vector()
208217
assert op.get_test_values(x) == []
209218

0 commit comments

Comments
 (0)