Skip to content

Commit 736da90

Browse files
ricardoV94jessegrabowski
authored andcommitted
Avoid duplicated inputs in KroneckerProduct OpFromGraph
1 parent 9c96ce2 commit 736da90

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

pytensor/compile/builders.py

+9
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,15 @@ def __init__(
400400
Check :func:`pytensor.function` for more arguments, only works when not
401401
inline.
402402
"""
403+
ignore_unused_inputs = kwargs.get("on_unused_input", False) == "ignore"
404+
if not ignore_unused_inputs and len(inputs) != len(set(inputs)):
405+
var_counts = {var: inputs.count(var) for var in inputs}
406+
duplicated_inputs = [var for var, count in var_counts.items() if count > 1]
407+
raise ValueError(
408+
f"There following variables were provided more than once as inputs to the OpFromGraph, resulting in an "
409+
f"invalid graph: {duplicated_inputs}. Use dummy variables or var.copy() to distinguish "
410+
f"variables when creating the OpFromGraph graph."
411+
)
403412

404413
if not (isinstance(inputs, list) and isinstance(outputs, list)):
405414
raise TypeError("Inputs and outputs must be lists")

pytensor/tensor/nlinalg.py

+5
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,11 @@ def kron(a, b):
10341034
"""
10351035
a = as_tensor_variable(a)
10361036
b = as_tensor_variable(b)
1037+
1038+
if a is b:
1039+
# In case a is the same as b, we need a different variable to build the OFG
1040+
b = a.copy()
1041+
10371042
if a.ndim + b.ndim <= 2:
10381043
raise TypeError(
10391044
"kron: inputs dimensions must sum to 3 or more. "

tests/compile/test_builders.py

+1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def test_grad_grad(self, cls_ofg):
118118
f = op(x, y, z)
119119
f = f - grad(pt_sum(f), y)
120120
f = f - grad(pt_sum(f), y)
121+
121122
fn = function([x, y, z], f)
122123
xv = np.ones((2, 2), dtype=config.floatX)
123124
yv = np.ones((2, 2), dtype=config.floatX) * 3

tests/tensor/test_slinalg.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,8 @@ def test_expm_grad_3():
514514
def test_solve_discrete_lyapunov_via_direct_real():
515515
N = 5
516516
rng = np.random.default_rng(utt.fetch_seed())
517-
a = pt.dmatrix()
518-
q = pt.dmatrix()
517+
a = pt.dmatrix("a")
518+
q = pt.dmatrix("q")
519519
f = function([a, q], [solve_discrete_lyapunov(a, q, method="direct")])
520520

521521
A = rng.normal(size=(N, N))

0 commit comments

Comments
 (0)