Skip to content

Commit 7a0ea76

Browse files
committed
Wrap literal constants in parenthesis in c-impl of ScalarOps
1 parent 718b1f7 commit 7a0ea76

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

pytensor/scalar/basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4349,7 +4349,7 @@ def c_code_template(self):
43494349
if var not in self.fgraph.inputs:
43504350
# This is an orphan
43514351
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
4352-
subd[var] = var.type.c_literal(var.data)
4352+
subd[var] = f"({var.type.c_literal(var.data)})"
43534353
else:
43544354
raise ValueError(
43554355
"All orphans in the fgraph to Composite must"
@@ -4408,7 +4408,7 @@ def c_code(self, node, nodename, inames, onames, sub):
44084408
return self.c_code_template % d
44094409

44104410
def c_code_cache_version_outer(self) -> tuple[int, ...]:
4411-
return (4,)
4411+
return (5,)
44124412

44134413

44144414
class Compositef32:

pytensor/scalar/loop.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def c_code_template(self):
239239
if var not in self.fgraph.inputs:
240240
# This is an orphan
241241
if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
242-
subd[var] = var.type.c_literal(var.data)
242+
subd[var] = f"({var.type.c_literal(var.data)})"
243243
else:
244244
raise ValueError(
245245
"All orphans in the fgraph to ScalarLoop must"
@@ -342,4 +342,4 @@ def c_code(self, node, nodename, inames, onames, sub):
342342
return res
343343

344344
def c_code_cache_version_outer(self):
345-
return (2,)
345+
return (3,)

tests/scalar/test_basic.py

+17
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
floats,
3737
int8,
3838
int32,
39+
int64,
3940
ints,
4041
invert,
4142
log,
@@ -44,6 +45,7 @@
4445
log10,
4546
mean,
4647
mul,
48+
neg,
4749
neq,
4850
rad2deg,
4951
reciprocal,
@@ -156,6 +158,21 @@ def checker(x, y):
156158
(literal_value + test_y) * (test_x / test_y),
157159
)
158160

161+
def test_negative_constant(self):
162+
# Test that a negative constant is wrapped in parentheses to avoid confusing - (unary minus) and -- (decrement)
163+
x = int64("x")
164+
e = neg(constant(-1.5)) % x
165+
comp_op = Composite([x], [e])
166+
comp_node = comp_op.make_node(x)
167+
168+
c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
169+
assert "-1.5" in c_code
170+
171+
g = FunctionGraph([x], [comp_node.out])
172+
fn = make_function(DualLinker().accept(g))
173+
assert fn(2) == 1.5
174+
assert fn(1) == 0.5
175+
159176
def test_many_outputs(self):
160177
x, y, z = floats("xyz")
161178
e0 = x + y + z

0 commit comments

Comments
 (0)