Skip to content

Commit 4e47797

Browse files
committed
Raise explicitly on Python methods that are incompatible with lazy variables
Notably, this changes the behavior of `__bool__` so that it always raises. Previously, there was a hack based on whether a variable had already been compared to something.
1 parent c22e79e commit 4e47797

File tree

7 files changed

+79
-43
lines changed

7 files changed

+79
-43
lines changed

pytensor/compile/function/types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def std_fgraph(
198198
update_mapping = {}
199199
out_idx = len(output_specs)
200200
for idx, input_spec in enumerate(input_specs):
201-
if input_spec.update:
201+
if input_spec.update is not None:
202202
updates.append(input_spec.update)
203203
update_mapping[out_idx] = idx
204204
out_idx += 1
@@ -1195,7 +1195,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
11951195
updated_fgraph_inputs = {
11961196
fgraph_i
11971197
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
1198-
if getattr(i, "update", False)
1198+
if getattr(i, "update", None) is not None
11991199
}
12001200

12011201
# We can't use fgraph.inputs as this don't include Constant Value.
@@ -1351,7 +1351,11 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
13511351
ancestors(
13521352
(
13531353
[o.variable for o in outputs]
1354-
+ [i.update for i in inputs if getattr(i, "update", False)]
1354+
+ [
1355+
i.update
1356+
for i in inputs
1357+
if getattr(i, "update", None) is not None
1358+
]
13551359
),
13561360
blockers=[i.variable for i in inputs],
13571361
)

pytensor/scalar/basic.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,37 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
725725

726726

727727
class _scalar_py_operators:
728+
# These can't work because Python requires native output types
729+
def __bool__(self):
730+
raise TypeError(
731+
"ScalarVariable cannot be converted to Python boolean. "
732+
"Call `.astype(bool)` for the symbolic equivalent."
733+
)
734+
735+
def __index__(self):
736+
raise TypeError(
737+
"ScalarVariable cannot be converted to Python integer. "
738+
"Call `.astype(int)` for the symbolic equivalent."
739+
)
740+
741+
def __int__(self):
742+
raise TypeError(
743+
"ScalarVariable cannot be converted to Python integer. "
744+
"Call `.astype(int)` for the symbolic equivalent."
745+
)
746+
747+
def __float__(self):
748+
raise TypeError(
749+
"ScalarVariable cannot be converted to Python float. "
750+
"Call `.astype(float)` for the symbolic equivalent."
751+
)
752+
753+
def __complex__(self):
754+
raise TypeError(
755+
"ScalarVariable cannot be converted to Python complex number. "
756+
"Call `.astype(complex)` for the symbolic equivalent."
757+
)
758+
728759
# So that we can simplify checking code when we have a mixture of ScalarType
729760
# variables and Tensor variables
730761
ndim = 0

pytensor/scalar/loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ def __init__(
6060
constant = []
6161
if not len(init) == len(update):
6262
raise ValueError("An update must be given for each init variable")
63-
if until:
63+
if until is not None:
6464
inputs, outputs = clone([*init, *constant], [*update, until])
6565
else:
6666
inputs, outputs = clone([*init, *constant], update)
6767

68-
self.is_while = bool(until)
68+
self.is_while = until is not None
6969
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
7070
self._validate_updates(self.inputs, self.outputs)
7171

pytensor/scan/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2582,7 +2582,7 @@ def compute_all_gradients(known_grads):
25822582

25832583
# mask inputs that get no gradients
25842584
for dx in range(len(dC_dinps_t)):
2585-
if not dC_dinps_t[dx]:
2585+
if dC_dinps_t[dx] is None:
25862586
dC_dinps_t[dx] = pt.zeros_like(diff_inputs[dx])
25872587
else:
25882588
disconnected_dC_dinps_t[dx] = False

pytensor/tensor/conv/abstract_conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2198,7 +2198,7 @@ def __init__(
21982198
):
21992199
border_mode = "valid"
22002200

2201-
self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim)
2201+
self.imshp = tuple(imshp) if imshp is not None else (None,) * (2 + convdim)
22022202
for imshp_i in self.imshp:
22032203
if imshp_i is not None:
22042204
# Components of imshp should be constant or ints
@@ -2208,7 +2208,7 @@ def __init__(
22082208
raise ValueError(
22092209
"imshp should be None or a tuple of constant int values"
22102210
).with_traceback(sys.exc_info()[2])
2211-
if kshp:
2211+
if kshp is not None:
22122212
self.kshp = tuple(kshp)
22132213
else:
22142214
self.kshp = (None,) * ((2 + 2 * convdim) if unshared else (2 + convdim))

pytensor/tensor/rewriting/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1992,7 +1992,7 @@ def local_pow_canonicalize(fgraph, node):
19921992
# x ** 1 = x
19931993
new_out = broadcast_arrays(*node.inputs)[0]
19941994

1995-
if not new_out:
1995+
if new_out is None:
19961996
return
19971997

19981998
if new_out.dtype != node.out.dtype:

pytensor/tensor/variable.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,53 +26,54 @@
2626

2727

2828
class _tensor_py_operators:
29+
# These can't work because Python requires native output types
30+
def __bool__(self):
31+
raise TypeError(
32+
"TensorVariable cannot be converted to Python boolean. "
33+
"Call `.astype(bool)` for the symbolic equivalent."
34+
)
35+
36+
def __index__(self):
37+
raise TypeError(
38+
"TensorVariable cannot be converted to Python integer. "
39+
"Call `.astype(int)` for the symbolic equivalent."
40+
)
41+
42+
def __int__(self):
43+
raise TypeError(
44+
"TensorVariable cannot be converted to Python integer. "
45+
"Call `.astype(int)` for the symbolic equivalent."
46+
)
47+
48+
def __float__(self):
49+
raise TypeError(
50+
"TensorVariables cannot be converted to Python float. "
51+
"Call `.astype(float)` for the symbolic equivalent."
52+
)
53+
54+
def __complex__(self):
55+
raise TypeError(
56+
"TensorVariables cannot be converted to Python complex number. "
57+
"Call `.astype(complex)` for the symbolic equivalent."
58+
)
59+
2960
def __abs__(self):
3061
return pt.math.abs(self)
3162

3263
def __neg__(self):
3364
return pt.math.neg(self)
3465

35-
# These won't work because Python requires an int return value
36-
# def __int__(self): return convert_to_int32(self)
37-
# def __float__(self): return convert_to_float64(self)
38-
# def __complex__(self): return convert_to_complex128(self)
39-
40-
_is_nonzero = True
41-
4266
def __lt__(self, other):
43-
rval = pt.math.lt(self, other)
44-
rval._is_nonzero = False
45-
return rval
67+
return pt.math.lt(self, other)
4668

4769
def __le__(self, other):
48-
rval = pt.math.le(self, other)
49-
rval._is_nonzero = False
50-
return rval
70+
return pt.math.le(self, other)
5171

5272
def __gt__(self, other):
53-
rval = pt.math.gt(self, other)
54-
rval._is_nonzero = False
55-
return rval
73+
return pt.math.gt(self, other)
5674

5775
def __ge__(self, other):
58-
rval = pt.math.ge(self, other)
59-
rval._is_nonzero = False
60-
return rval
61-
62-
def __bool__(self):
63-
# This is meant to prohibit stuff like a < b < c, which is internally
64-
# implemented as (a < b) and (b < c). The trouble with this is the
65-
# side-effect that checking for a non-NULL a by typing "if a: ..."
66-
# uses the same __nonzero__ method. We want these both to work, but
67-
# it seems impossible. Currently, all vars evaluate to nonzero except
68-
# the return values of comparison operators, which raise this
69-
# exception. If you can think of a better solution, go for it!
70-
#
71-
# __bool__ is Python 3.x data model. __nonzero__ is Python 2.x.
72-
if self._is_nonzero:
73-
return True
74-
else:
75-
raise TypeError("Variables do not support boolean operations.")
76+
return pt.math.ge(self, other)
7677

7778
def __invert__(self):
7879
return pt.math.invert(self)

0 commit comments

Comments
 (0)