Skip to content

Commit 8b740f6

Browse files
Added pytest Future Warning in relavant tests
1 parent f3abb76 commit 8b740f6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+666
-589
lines changed

pytensor/configdefaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import socket
77
import sys
88
import textwrap
9+
import warnings
910

1011
import numpy as np
1112
from setuptools._distutils.spawn import find_executable
@@ -1447,6 +1448,12 @@ def add_caching_dir_configvars():
14471448
else:
14481449
gcc_version_str = "GCC_NOT_FOUND"
14491450

1451+
if config.compute_test_value != "off":
1452+
warnings.warn(
1453+
"test_value machinery is deprecated and will stop working in the future.",
1454+
FutureWarning,
1455+
)
1456+
14501457
# TODO: The caching dir resolution is a procedural mess of helper functions, local variables
14511458
# and config definitions. And the result is also not particularly pretty..
14521459
add_caching_dir_configvars()

pytensor/graph/basic.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,12 @@ def __init__(
451451

452452
self.tag = ValidatingScratchpad("test_value", type.filter)
453453

454+
# if hasattr(self.tag, "test_value"):
455+
# warnings.warn(
456+
# "test_value machinery is deprecated and will stop working in the future.",
457+
# FutureWarning,
458+
# )
459+
454460
self.type = type
455461

456462
self._owner = owner
@@ -479,10 +485,7 @@ def get_test_value(self):
479485
if not hasattr(self.tag, "test_value"):
480486
detailed_err_msg = get_variable_trace_string(self)
481487
raise TestValueError(f"{self} has no test value {detailed_err_msg}")
482-
warnings.warn(
483-
"test_value machinery is deprecated and will stop working in the future.",
484-
FutureWarning,
485-
)
488+
486489
return self.tag.test_value
487490

488491
def __str__(self):

pytensor/graph/fg.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
22

33
import time
4-
import warnings
54
from collections import OrderedDict
65
from collections.abc import Iterable, Sequence
76
from typing import TYPE_CHECKING, Any, Literal, Union, cast
@@ -494,10 +493,6 @@ def replace(
494493
return
495494

496495
if config.compute_test_value != "off":
497-
warnings.warn(
498-
"test_value machinery is deprecated and will stop working in the future.",
499-
FutureWarning,
500-
)
501496
try:
502497
tval = pytensor.graph.op.get_test_value(var)
503498
new_tval = pytensor.graph.op.get_test_value(new_var)

pytensor/graph/op.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,6 @@ def compute_test_value(node: Apply):
7070
"""
7171
# Gather the test values for each input of the node
7272

73-
warnings.warn(
74-
"compute_test_value is deprecated and will stop working in the future.",
75-
FutureWarning,
76-
)
77-
7873
storage_map = {}
7974
compute_map = {}
8075
for i, ins in enumerate(node.inputs):
@@ -307,10 +302,6 @@ def __call__(
307302
n.name = f"{name}_{i}"
308303

309304
if config.compute_test_value != "off":
310-
warnings.warn(
311-
"test_value machinery is deprecated and will stop working in the future.",
312-
FutureWarning,
313-
)
314305
compute_test_value(node)
315306

316307
if self.default_output is not None:
@@ -721,11 +712,6 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
721712
if config.compute_test_value == "off":
722713
return []
723714

724-
warnings.warn(
725-
"test_value machinery is deprecated and will stop working in the future.",
726-
FutureWarning,
727-
)
728-
729715
rval = []
730716

731717
for i, arg in enumerate(args):

pytensor/graph/rewriting/basic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -983,10 +983,6 @@ def transform(self, fgraph, node, *args, **kwargs):
983983
if isinstance(input, pytensor.compile.SharedVariable):
984984
pass
985985
elif hasattr(input.tag, "test_value"):
986-
warnings.warn(
987-
"compute_test_value is deprecated and will stop working in the future.",
988-
FutureWarning,
989-
)
990986
givens[input] = pytensor.shared(
991987
input.type.filter(input.tag.test_value),
992988
input.name,

pytensor/graph/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import linecache
22
import sys
33
import traceback
4+
import warnings
45
from abc import ABCMeta
56
from collections.abc import Sequence
67
from io import StringIO
@@ -283,9 +284,19 @@ def info(self):
283284

284285
# These two methods have been added to help Mypy
285286
def __getattribute__(self, name):
287+
if name == "test_value":
288+
warnings.warn(
289+
"test_value machinery is deprecated and will stop working in the future.",
290+
FutureWarning,
291+
)
286292
return super().__getattribute__(name)
287293

288294
def __setattr__(self, name: str, value: Any) -> None:
295+
if name == "test_value":
296+
warnings.warn(
297+
"test_value machinery is deprecated and will stop working in the future.",
298+
FutureWarning,
299+
)
289300
self.__dict__[name] = value
290301

291302

@@ -300,6 +311,11 @@ def __init__(self, attr, attr_filter):
300311

301312
def __setattr__(self, attr, obj):
302313
if getattr(self, "attr", None) == attr:
314+
if attr == "test_value":
315+
warnings.warn(
316+
"test_value machinery is deprecated and will stop working in the future.",
317+
FutureWarning,
318+
)
303319
obj = self.attr_filter(obj)
304320

305321
return object.__setattr__(self, attr, obj)

pytensor/misc/pkl_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import pickle
1010
import sys
1111
import tempfile
12-
import warnings
1312
import zipfile
1413
from collections import defaultdict
1514
from contextlib import closing
@@ -62,10 +61,6 @@ class StripPickler(Pickler):
6261
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
6362
# Can't use super as Pickler isn't a new style class
6463
super().__init__(file, protocol)
65-
warnings.warn(
66-
"compute_test_value is deprecated and will stop working in the future.",
67-
FutureWarning,
68-
)
6964
self.tag_to_remove = ["trace", "test_value"]
7065
if extra_tag_to_remove:
7166
self.tag_to_remove.extend(extra_tag_to_remove)

pytensor/scalar/basic.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import builtins
1414
import math
15-
import warnings
1615
from collections.abc import Callable, Mapping
1716
from copy import copy
1817
from itertools import chain
@@ -4415,10 +4414,6 @@ def apply(self, fgraph):
44154414
if i.dtype == "float16":
44164415
mapping[i] = get_scalar_type("float32")()
44174416
if hasattr(i.tag, "test_value"):
4418-
warnings.warn(
4419-
"test_value machinery is deprecated and will stop working in the future.",
4420-
FutureWarning,
4421-
)
44224417
mapping[i].tag.test_value = i.tag.test_value
44234418
else:
44244419
mapping[i] = i

pytensor/scan/basic.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -598,10 +598,6 @@ def wrap_into_list(x):
598598

599599
# Try to transfer test_value to the new variable
600600
if config.compute_test_value != "off":
601-
warnings.warn(
602-
"test_value machinery is deprecated and will stop working in the future.",
603-
FutureWarning,
604-
)
605601
try:
606602
nw_slice.tag.test_value = get_test_value(_seq_val_slice)
607603
except TestValueError:
@@ -729,10 +725,6 @@ def wrap_into_list(x):
729725

730726
# Try to transfer test_value to the new variable
731727
if config.compute_test_value != "off":
732-
warnings.warn(
733-
"test_value machinery is deprecated and will stop working in the future.",
734-
FutureWarning,
735-
)
736728
try:
737729
arg.tag.test_value = get_test_value(actual_arg)
738730
except TestValueError:
@@ -788,10 +780,6 @@ def wrap_into_list(x):
788780

789781
# Try to transfer test_value to the new variable
790782
if config.compute_test_value != "off":
791-
warnings.warn(
792-
"test_value machinery is deprecated and will stop working in the future.",
793-
FutureWarning,
794-
)
795783
try:
796784
nw_slice.tag.test_value = get_test_value(_init_out_var_slice)
797785
except TestValueError:

pytensor/scan/op.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import dataclasses
4747
import logging
4848
import time
49-
import warnings
5049
from collections import OrderedDict
5150
from collections.abc import Callable, Iterable
5251
from copy import copy
@@ -2651,10 +2650,6 @@ def compute_all_gradients(known_grads):
26512650
# fct add and we want to keep it for all Scan op. This is
26522651
# used in T_Scan.test_grad_multiple_outs_taps to test
26532652
# that.
2654-
warnings.warn(
2655-
"test_value machinery is deprecated and will stop working in the future.",
2656-
FutureWarning,
2657-
)
26582653
if info.as_while:
26592654
n = n_steps.tag.test_value
26602655
else:

pytensor/scan/rewriting.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import copy
44
import dataclasses
5-
import warnings
65
from itertools import chain
76
from sys import maxsize
87
from typing import cast
@@ -306,10 +305,6 @@ def add_to_replace(y):
306305
pushed_out_node = nd.op.make_node(*new_inputs)
307306

308307
if config.compute_test_value != "off":
309-
warnings.warn(
310-
"test_value machinery is deprecated and will stop working in the future.",
311-
FutureWarning,
312-
)
313308
compute_test_value(pushed_out_node)
314309

315310
# Step 2. Create variables to replace the old outputs of the node
@@ -516,10 +511,6 @@ def add_to_replace(y):
516511
nw_outer_node = nd.op.make_node(*outside_ins)
517512

518513
if config.compute_test_value != "off":
519-
warnings.warn(
520-
"test_value machinery is deprecated and will stop working in the future.",
521-
FutureWarning,
522-
)
523514
compute_test_value(nw_outer_node)
524515

525516
# Step 2. Create variables for replacements
@@ -554,10 +545,6 @@ def add_to_replace(y):
554545
replace_with_out.append(new_outer)
555546

556547
if hasattr(new_outer.tag, "test_value"):
557-
warnings.warn(
558-
"test_value machinery is deprecated and will stop working in the future.",
559-
FutureWarning,
560-
)
561548
new_sh = new_outer.tag.test_value.shape
562549
ref_sh = (outside_ins.tag.test_value.shape[0],)
563550
ref_sh += nd.outputs[0].tag.test_value.shape
@@ -995,10 +982,6 @@ def attempt_scan_inplace(
995982
new_lsi = inp.owner.op.make_node(*inp.owner.inputs)
996983

997984
if config.compute_test_value != "off":
998-
warnings.warn(
999-
"test_value machinery is deprecated and will stop working in the future.",
1000-
FutureWarning,
1001-
)
1002985
compute_test_value(new_lsi)
1003986

1004987
new_lsi_out = new_lsi.outputs

pytensor/scan/utils.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import copy
44
import dataclasses
55
import logging
6-
import warnings
76
from collections import OrderedDict, namedtuple
87
from collections.abc import Callable, Sequence
98
from itertools import chain
@@ -75,10 +74,6 @@ def safe_new(
7574
nw_x.name = nw_name
7675
if config.compute_test_value != "off":
7776
# Copy test value, cast it if necessary
78-
warnings.warn(
79-
"test_value machinery is deprecated and will stop working in the future.",
80-
FutureWarning,
81-
)
8277
try:
8378
x_test_value = get_test_value(x)
8479
except TestValueError:
@@ -109,10 +104,6 @@ def safe_new(
109104
# between test values, due to inplace operations for instance. This may
110105
# not be the most efficient memory-wise, though.
111106
if config.compute_test_value != "off":
112-
warnings.warn(
113-
"test_value machinery is deprecated and will stop working in the future.",
114-
FutureWarning,
115-
)
116107
try:
117108
nw_x.tag.test_value = copy.deepcopy(get_test_value(x))
118109
except TestValueError:

pytensor/tensor/blas.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@
7878
import logging
7979
import os
8080
import time
81-
import warnings
8281

8382
import numpy as np
8483

@@ -1968,10 +1967,6 @@ def R_op(self, inputs, eval_points):
19681967
test_values_enabled = config.compute_test_value != "off"
19691968

19701969
if test_values_enabled:
1971-
warnings.warn(
1972-
"test_value machinery is deprecated and will stop working in the future.",
1973-
FutureWarning,
1974-
)
19751970
try:
19761971
iv0 = pytensor.graph.op.get_test_value(inputs[0])
19771972
except TestValueError:

pytensor/tensor/random/rewriting/basic.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from itertools import chain
32

43
from pytensor.compile import optdb
@@ -109,10 +108,6 @@ def local_rv_size_lift(fgraph, node):
109108
new_node = node.op.make_node(rng, None, *dist_params)
110109

111110
if config.compute_test_value != "off":
112-
warnings.warn(
113-
"test_value machinery is deprecated and will stop working in the future.",
114-
FutureWarning,
115-
)
116111
compute_test_value(new_node)
117112

118113
return new_node.outputs
@@ -192,10 +187,6 @@ def local_dimshuffle_rv_lift(fgraph, node):
192187
new_node = rv_op.make_node(rng, new_size, *new_dist_params)
193188

194189
if config.compute_test_value != "off":
195-
warnings.warn(
196-
"test_value machinery is deprecated and will stop working in the future.",
197-
FutureWarning,
198-
)
199190
compute_test_value(new_node)
200191

201192
out = new_node.outputs[1]

tests/compile/test_builders.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,11 +523,12 @@ def test_infer_shape(self):
523523

524524
@config.change_flags(compute_test_value="raise")
525525
def test_compute_test_value(self):
526-
x = scalar("x")
527-
x.tag.test_value = np.array(1.0, dtype=config.floatX)
528-
op = OpFromGraph([x], [x**3])
529-
y = scalar("y")
530-
y.tag.test_value = np.array(1.0, dtype=config.floatX)
526+
with pytest.warns(FutureWarning):
527+
x = scalar("x")
528+
x.tag.test_value = np.array(1.0, dtype=config.floatX)
529+
op = OpFromGraph([x], [x**3])
530+
y = scalar("y")
531+
y.tag.test_value = np.array(1.0, dtype=config.floatX)
531532
f = op(y)
532533
grad_f = grad(f, y)
533534
assert grad_f.tag.test_value is not None

0 commit comments

Comments
 (0)