Skip to content

Commit 6d43bf4

Browse files
Raise warnings when test_val is accessed
Added pytest Future Warning in relavant tests Removed and replaced usage of test_value in JAX/Numba tests
1 parent ee4d4f7 commit 6d43bf4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+49513
-580
lines changed

coverage/coverage-.xml

Lines changed: 48682 additions & 0 deletions
Large diffs are not rendered by default.

pytensor/compile/sharedvalue.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Provide a simple user friendly API to PyTensor-managed memory."""
22

33
import copy
4+
import warnings
45
from contextlib import contextmanager
56
from functools import singledispatch
67
from typing import TYPE_CHECKING
@@ -134,6 +135,10 @@ def set_value(self, new_value, borrow=False):
134135
self.container.value = copy.deepcopy(new_value)
135136

136137
def get_test_value(self):
138+
warnings.warn(
139+
"test_value machinery is deprecated and will stop working in the future.",
140+
FutureWarning,
141+
)
137142
return self.get_value(borrow=True, return_internal_type=True)
138143

139144
def clone(self, **kwargs):

pytensor/configdefaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import sys
88
import textwrap
9+
import warnings
910
from pathlib import Path
1011

1112
import numpy as np
@@ -1378,6 +1379,12 @@ def add_caching_dir_configvars():
13781379
else:
13791380
gcc_version_str = "GCC_NOT_FOUND"
13801381

1382+
if config.compute_test_value != "off":
1383+
warnings.warn(
1384+
"test_value machinery is deprecated and will stop working in the future.",
1385+
FutureWarning,
1386+
)
1387+
13811388
# TODO: The caching dir resolution is a procedural mess of helper functions, local variables
13821389
# and config definitions. And the result is also not particularly pretty..
13831390
add_caching_dir_configvars()

pytensor/graph/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,10 @@ def __init__(self, type: _TypeType, data: Any, name: str | None = None):
784784
add_tag_trace(self)
785785

786786
def get_test_value(self):
787+
warnings.warn(
788+
"test_value machinery is deprecated and will stop working in the future.",
789+
FutureWarning,
790+
)
787791
return self.data
788792

789793
def signature(self):

pytensor/graph/op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,11 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
711711
if config.compute_test_value == "off":
712712
return []
713713

714+
warnings.warn(
715+
"test_value machinery is deprecated and will stop working in the future.",
716+
FutureWarning,
717+
)
718+
714719
rval = []
715720

716721
for i, arg in enumerate(args):

pytensor/graph/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import linecache
22
import sys
33
import traceback
4+
import warnings
45
from abc import ABCMeta
56
from collections.abc import Sequence
67
from io import StringIO
@@ -283,9 +284,19 @@ def info(self):
283284

284285
# These two methods have been added to help Mypy
285286
def __getattribute__(self, name):
287+
if name == "test_value":
288+
warnings.warn(
289+
"test_value machinery is deprecated and will stop working in the future.",
290+
FutureWarning,
291+
)
286292
return super().__getattribute__(name)
287293

288294
def __setattr__(self, name: str, value: Any) -> None:
295+
if name == "test_value":
296+
warnings.warn(
297+
"test_value machinery is deprecated and will stop working in the future.",
298+
FutureWarning,
299+
)
289300
self.__dict__[name] = value
290301

291302

@@ -300,6 +311,11 @@ def __init__(self, attr, attr_filter):
300311

301312
def __setattr__(self, attr, obj):
302313
if getattr(self, "attr", None) == attr:
314+
if attr == "test_value":
315+
warnings.warn(
316+
"test_value machinery is deprecated and will stop working in the future.",
317+
FutureWarning,
318+
)
303319
obj = self.attr_filter(obj)
304320

305321
return object.__setattr__(self, attr, obj)

pytensor/scalar/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,7 @@ def supports_c_code(self, inputs, outputs):
12221222
tmp_s_input.append(tmp)
12231223
mapping[ii] = tmp_s_input[-1]
12241224

1225-
with config.change_flags(compute_test_value="ignore"):
1225+
with config.change_flags(compute_test_value="off"):
12261226
s_op = self(*tmp_s_input, return_list=True)
12271227

12281228
# if the scalar_op don't have a c implementation,

tests/compile/test_builders.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,12 @@ def test_infer_shape(self):
523523

524524
@config.change_flags(compute_test_value="raise")
525525
def test_compute_test_value(self):
526-
x = scalar("x")
527-
x.tag.test_value = np.array(1.0, dtype=config.floatX)
528-
op = OpFromGraph([x], [x**3])
529-
y = scalar("y")
530-
y.tag.test_value = np.array(1.0, dtype=config.floatX)
526+
with pytest.warns(FutureWarning):
527+
x = scalar("x")
528+
x.tag.test_value = np.array(1.0, dtype=config.floatX)
529+
op = OpFromGraph([x], [x**3])
530+
y = scalar("y")
531+
y.tag.test_value = np.array(1.0, dtype=config.floatX)
531532
f = op(y)
532533
grad_f = grad(f, y)
533534
assert grad_f.tag.test_value is not None

tests/compile/test_ops.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ def cumprod(x):
3333

3434
def test_2arg(self):
3535
x = dmatrix("x")
36-
x.tag.test_value = np.zeros((2, 2))
3736
y = dvector("y")
38-
y.tag.test_value = [0, 0, 0, 0]
3937

4038
@as_op([dmatrix, dvector], dvector)
4139
def cumprod_plus(x, y):
@@ -49,9 +47,7 @@ def cumprod_plus(x, y):
4947

5048
def test_infer_shape(self):
5149
x = dmatrix("x")
52-
x.tag.test_value = np.zeros((2, 2))
5350
y = dvector("y")
54-
y.tag.test_value = [0, 0, 0, 0]
5551

5652
def infer_shape(fgraph, node, shapes):
5753
x, y = shapes

tests/graph/test_compute_test_value.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def perform(self, node, inputs, outputs):
6767

6868
test_input = SomeType()()
6969
orig_object = object()
70-
test_input.tag.test_value = orig_object
70+
with pytest.warns(FutureWarning):
71+
test_input.tag.test_value = orig_object
7172

7273
res = InplaceOp(False)(test_input)
7374
assert res.tag.test_value is orig_object
@@ -76,10 +77,11 @@ def perform(self, node, inputs, outputs):
7677
assert res.tag.test_value is not orig_object
7778

7879
def test_variable_only(self):
79-
x = matrix("x")
80-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
81-
y = matrix("y")
82-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
80+
with pytest.warns(FutureWarning):
81+
x = matrix("x")
82+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
83+
y = matrix("y")
84+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
8385

8486
# should work
8587
z = dot(x, y)
@@ -88,14 +90,16 @@ def test_variable_only(self):
8890
assert _allclose(f(x.tag.test_value, y.tag.test_value), z.tag.test_value)
8991

9092
# this test should fail
91-
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
93+
with pytest.warns(FutureWarning):
94+
y.tag.test_value = np.random.random((6, 5)).astype(config.floatX)
9295
with pytest.raises(ValueError):
9396
dot(x, y)
9497

9598
def test_compute_flag(self):
9699
x = matrix("x")
97100
y = matrix("y")
98-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
101+
with pytest.warns(FutureWarning):
102+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
99103

100104
# should skip computation of test value
101105
with config.change_flags(compute_test_value="off"):
@@ -111,10 +115,11 @@ def test_compute_flag(self):
111115
dot(x, y)
112116

113117
def test_string_var(self):
114-
x = matrix("x")
115-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
116-
y = matrix("y")
117-
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
118+
with pytest.warns(FutureWarning):
119+
x = matrix("x")
120+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
121+
y = matrix("y")
122+
y.tag.test_value = np.random.random((4, 5)).astype(config.floatX)
118123

119124
z = pytensor.shared(np.random.random((5, 6)).astype(config.floatX))
120125

@@ -134,7 +139,8 @@ def f(x, y, z):
134139

135140
def test_shared(self):
136141
x = matrix("x")
137-
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
142+
with pytest.warns(FutureWarning):
143+
x.tag.test_value = np.random.random((3, 4)).astype(config.floatX)
138144
y = pytensor.shared(np.random.random((4, 6)).astype(config.floatX), "y")
139145

140146
# should work
@@ -190,30 +196,31 @@ def test_constant(self):
190196
def test_incorrect_type(self):
191197
x = vector("x")
192198
with pytest.raises(TypeError):
193-
# Incorrect shape for test value
194-
x.tag.test_value = np.empty((2, 2))
199+
with pytest.warns(FutureWarning):
200+
# Incorrect shape for test value
201+
x.tag.test_value = np.empty((2, 2))
195202

196203
x = fmatrix("x")
197204
with pytest.raises(TypeError):
198-
# Incorrect dtype (float64) for test value
199-
x.tag.test_value = np.random.random((3, 4))
205+
with pytest.warns(FutureWarning):
206+
# Incorrect dtype (float64) for test value
207+
x.tag.test_value = np.random.random((3, 4))
200208

201209
def test_overided_function(self):
202210
# We need to test those as they mess with Exception
203211
# And we don't want the exception to be changed.
204212
x = matrix()
205-
x.tag.test_value = np.zeros((2, 3), dtype=config.floatX)
206213
y = matrix()
207-
y.tag.test_value = np.zeros((2, 2), dtype=config.floatX)
208214
with pytest.raises(ValueError):
209215
x.__mul__(y)
210216

211217
def test_scan(self):
212218
# Test the compute_test_value mechanism Scan.
213219
k = iscalar("k")
214220
A = vector("A")
215-
k.tag.test_value = 3
216-
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
221+
with pytest.warns(FutureWarning):
222+
k.tag.test_value = 3
223+
A.tag.test_value = np.random.random((5,)).astype(config.floatX)
217224

218225
def fx(prior_result, A):
219226
return prior_result * A
@@ -233,8 +240,9 @@ def test_scan_err1(self):
233240
# This test should fail when building fx for the first time
234241
k = iscalar("k")
235242
A = matrix("A")
236-
k.tag.test_value = 3
237-
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
243+
with pytest.warns(FutureWarning):
244+
k.tag.test_value = 3
245+
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
238246

239247
def fx(prior_result, A):
240248
return dot(prior_result, A)
@@ -253,8 +261,9 @@ def test_scan_err2(self):
253261
# but when calling the scan's perform()
254262
k = iscalar("k")
255263
A = matrix("A")
256-
k.tag.test_value = 3
257-
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
264+
with pytest.warns(FutureWarning):
265+
k.tag.test_value = 3
266+
A.tag.test_value = np.random.random((5, 3)).astype(config.floatX)
258267

259268
def fx(prior_result, A):
260269
return dot(prior_result, A)
@@ -288,7 +297,8 @@ def perform(self, node, inputs, outputs):
288297
output[0] = input + 1
289298

290299
i = ps.int32("i")
291-
i.tag.test_value = 3
300+
with pytest.warns(FutureWarning):
301+
i.tag.test_value = 3
292302

293303
o = IncOnePython()(i)
294304

@@ -304,7 +314,8 @@ def perform(self, node, inputs, outputs):
304314
)
305315
def test_no_perform(self):
306316
i = ps.int32("i")
307-
i.tag.test_value = 3
317+
with pytest.warns(FutureWarning):
318+
i.tag.test_value = 3
308319

309320
# Class IncOneC is defined outside of the TestComputeTestValue
310321
# so it can be pickled and unpickled

tests/graph/test_fg.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -241,18 +241,19 @@ def test_change_input(self):
241241

242242
@config.change_flags(compute_test_value="raise")
243243
def test_replace_test_value(self):
244-
var1 = MyVariable("var1")
245-
var1.tag.test_value = 1
246-
var2 = MyVariable("var2")
247-
var2.tag.test_value = 2
248-
var3 = op1(var2, var1)
249-
var4 = op2(var3, var2)
250-
var4.tag.test_value = np.array([1, 2])
251-
var5 = op3(var4, var2, var2)
252-
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
253-
254-
var6 = op3()
255-
var6.tag.test_value = np.array(0)
244+
with pytest.warns(FutureWarning):
245+
var1 = MyVariable("var1")
246+
var1.tag.test_value = 1
247+
var2 = MyVariable("var2")
248+
var2.tag.test_value = 2
249+
var3 = op1(var2, var1)
250+
var4 = op2(var3, var2)
251+
var4.tag.test_value = np.array([1, 2])
252+
var5 = op3(var4, var2, var2)
253+
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
254+
255+
var6 = op3()
256+
var6.tag.test_value = np.array(0)
256257

257258
assert var6.tag.test_value.shape != var4.tag.test_value.shape
258259

0 commit comments

Comments
 (0)