Skip to content

Commit 1aa9cb6

Browse files
committed
Added rewrite for diag of kronecker product
1 parent a3f0a4e commit 1aa9cb6

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
from pytensor.scalar.basic import Mul
1212
from pytensor.tensor.basic import (
1313
AllocDiag,
14+
ExtractDiag,
1415
Eye,
1516
TensorVariable,
17+
diag,
1618
diagonal,
1719
)
1820
from pytensor.tensor.blas import Dot22
1921
from pytensor.tensor.blockwise import Blockwise
2022
from pytensor.tensor.elemwise import DimShuffle, Elemwise
21-
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
23+
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
2224
from pytensor.tensor.nlinalg import (
2325
SVD,
2426
KroneckerProduct,
@@ -611,3 +613,20 @@ def rewrite_inv_inv(fgraph, node):
611613
):
612614
return None
613615
return [potential_inner_inv.inputs[0]]
616+
617+
618+
@register_canonicalize
619+
@register_stabilize
620+
@node_rewriter([ExtractDiag])
621+
def rewrite_diag_kronecker(fgraph, node):
622+
# Check for inner kron operation
623+
potential_kron = node.inputs[0].owner
624+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
625+
return None
626+
627+
# Find the matrices
628+
a, b = potential_kron.inputs
629+
diag_a, diag_b = diag(a), diag(b)
630+
outer_prod_as_vector = outer(diag_a, diag_b).flatten()
631+
632+
return [outer_prod_as_vector]

tests/tensor/rewriting/test_linalg.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,3 +568,26 @@ def get_pt_function(x, op_name):
568568
op2 = get_pt_function(op1, inv_op_2)
569569
rewritten_out = rewrite_graph(op2)
570570
assert rewritten_out == x
571+
572+
573+
def test_diag_kronecker_rewrite():
574+
a, b = pt.dmatrices("a", "b")
575+
kron_prod = pt.linalg.kron(a, b)
576+
diag_kron_prod = pt.diag(kron_prod)
577+
f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN")
578+
579+
# Rewrite Test
580+
nodes = f_rewritten.maker.fgraph.apply_nodes
581+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
582+
583+
# Value Test
584+
a_test, b_test = np.random.rand(2, 20, 20)
585+
kron_prod_test = np.kron(a_test, b_test)
586+
diag_kron_prod_test = np.diag(kron_prod_test)
587+
rewritten_val = f_rewritten(a_test, b_test)
588+
assert_allclose(
589+
diag_kron_prod_test,
590+
rewritten_val,
591+
atol=1e-3 if config.floatX == "float32" else 1e-8,
592+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
593+
)

0 commit comments

Comments
 (0)