Skip to content

Commit 979132c

Browse files
ricardoV94jessegrabowski
authored andcommitted
Avoid duplicated inputs in KroneckerProduct OpFromGraph
1 parent 567b8d3 commit 979132c

File tree

4 files changed

+33
-2
lines changed

4 files changed

+33
-2
lines changed

pytensor/compile/builders.py

+9
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,15 @@ def __init__(
400400
Check :func:`pytensor.function` for more arguments, only works when not
401401
inline.
402402
"""
403+
ignore_unused_inputs = kwargs.get("on_unused_input", False) == "ignore"
404+
if not ignore_unused_inputs and len(inputs) != len(set(inputs)):
405+
var_counts = {var: inputs.count(var) for var in inputs}
406+
duplicated_inputs = [var for var, count in var_counts.items() if count > 1]
407+
raise ValueError(
408+
f"There following variables were provided more than once as inputs to the OpFromGraph, resulting in an "
409+
f"invalid graph: {duplicated_inputs}. Use dummy variables or var.copy() to distinguish "
410+
f"variables when creating the OpFromGraph graph."
411+
)
403412

404413
if not (isinstance(inputs, list) and isinstance(outputs, list)):
405414
raise TypeError("Inputs and outputs must be lists")

pytensor/tensor/nlinalg.py

+5
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,11 @@ def kron(a, b):
10341034
"""
10351035
a = as_tensor_variable(a)
10361036
b = as_tensor_variable(b)
1037+
1038+
if a is b:
1039+
# In case a is the same as b, we need a different variable to build the OFG
1040+
b = a.copy()
1041+
10371042
if a.ndim + b.ndim <= 2:
10381043
raise TypeError(
10391044
"kron: inputs dimensions must sum to 3 or more. "

tests/compile/test_builders.py

+17
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def test_grad_grad(self, cls_ofg):
118118
f = op(x, y, z)
119119
f = f - grad(pt_sum(f), y)
120120
f = f - grad(pt_sum(f), y)
121+
121122
fn = function([x, y, z], f)
122123
xv = np.ones((2, 2), dtype=config.floatX)
123124
yv = np.ones((2, 2), dtype=config.floatX) * 3
@@ -584,6 +585,22 @@ def test_explicit_input_from_shared(self):
584585
out = test_ofg(y, y)
585586
assert out.eval() == 4
586587

588+
def test_repeated_inputs(self):
589+
x = pt.dscalar("x")
590+
y = pt.dscalar("y")
591+
592+
with pytest.raises(
593+
ValueError,
594+
match="There following variables were provided more than once as inputs to the "
595+
"OpFromGraph",
596+
):
597+
OpFromGraph([x, x], [x + y])
598+
599+
# Test that repeated inputs will be allowed if unused inputs are ignored
600+
g = OpFromGraph([x, x, y], [x + y], on_unused_input="ignore")
601+
f = g(x, x, y)
602+
assert f.eval({x: 5, y: 5}) == 10
603+
587604

588605
@config.change_flags(floatX="float64")
589606
def test_debugprint():

tests/tensor/test_slinalg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,8 @@ def test_expm_grad_3():
514514
def test_solve_discrete_lyapunov_via_direct_real():
515515
N = 5
516516
rng = np.random.default_rng(utt.fetch_seed())
517-
a = pt.dmatrix()
518-
q = pt.dmatrix()
517+
a = pt.dmatrix("a")
518+
q = pt.dmatrix("q")
519519
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
520520

521521
A = rng.normal(size=(N, N))

0 commit comments

Comments
 (0)