Skip to content

Commit e96a770

Browse files
committed
fixed failing tests and added rewrite for pt.diag
1 parent f445c5f commit e96a770

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Callable
33
from typing import cast
44

5+
import pytensor.tensor as pt
56
from pytensor import Variable
67
from pytensor.graph import Apply, FunctionGraph
78
from pytensor.graph.rewriting.basic import (
@@ -652,13 +653,24 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
652653
@register_canonicalize
653654
@register_stabilize
654655
@node_rewriter([Blockwise])
655-
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
656+
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
656657
# Only proceed when the Blockwise node wraps a Cholesky core op
657658
if not isinstance(node.op.core_op, Cholesky):
658659
return None
659660

660-
# Check whether input is diagonal from multiplication of identity matrix with a tensor
661661
inputs = node.inputs[0]
662+
# First handle an input built with pt.diag, i.e. an AllocDiag op with zero offset
663+
if (
664+
inputs.owner
665+
and isinstance(inputs.owner.op, AllocDiag)
666+
and AllocDiag.is_offset_zero(inputs.owner)
667+
):
668+
cholesky_input = inputs.owner.inputs[0]
669+
if cholesky_input.type.ndim == 1:
670+
cholesky_val = pt.diag(cholesky_input**0.5)
671+
return [cholesky_val]
672+
673+
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
662674
inputs_or_none = _find_diag_from_eye_mul(inputs)
663675
if inputs_or_none is None:
664676
return None
@@ -669,6 +681,13 @@ def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
669681
if len(non_eye_inputs) != 1:
670682
return None
671683

672-
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
684+
non_eye_input = non_eye_inputs[0]
673685

674-
return [eye_input * (non_eye_input**0.5)]
686+
# The Cholesky factor of a (positive) diagonal matrix is the diagonal matrix of the square roots of its diagonal entries
687+
# If the non-eye input is itself a matrix, only its diagonal survives the multiplication by eye, so extract that diagonal first
688+
if non_eye_input.type.broadcastable[-2:] == (False, False):
689+
# Matrix input: neither of the last two dimensions is broadcastable
690+
return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)]
691+
else:
692+
# For Vector or Scalar
693+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,8 @@ def test_cholesky_eye_rewrite():
599599

600600
@pytest.mark.parametrize(
601601
"shape",
602-
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
603-
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
602+
[(), (7,), (7, 7)],
603+
ids=["scalar", "vector", "matrix"],
604604
)
605605
def test_cholesky_diag_from_eye_mul(shape):
606606
# Initialize x as a scalar, vector, or matrix depending on the parametrized shape
@@ -631,3 +631,28 @@ def test_cholesky_diag_from_eye_mul(shape):
631631
atol=1e-3 if config.floatX == "float32" else 1e-8,
632632
rtol=1e-3 if config.floatX == "float32" else 1e-8,
633633
)
634+
635+
636+
def test_cholesky_diag_from_diag():
637+
x = pt.dvector("x")
638+
x_diag = pt.diag(x)
639+
x_cholesky = pt.linalg.cholesky(x_diag)
640+
641+
# Rewrite test: the compiled graph must no longer contain a Cholesky op
642+
f_rewritten = function([x], x_cholesky, mode="FAST_RUN")
643+
nodes = f_rewritten.maker.fgraph.apply_nodes
644+
645+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
646+
647+
# Numeric test: the rewritten output must match NumPy's Cholesky of the equivalent dense diagonal matrix
648+
x_test = np.random.rand(10)
649+
x_test_matrix = np.eye(10) * x_test
650+
cholesky_val = np.linalg.cholesky(x_test_matrix)
651+
rewritten_cholesky = f_rewritten(x_test)
652+
653+
assert_allclose(
654+
cholesky_val,
655+
rewritten_cholesky,
656+
atol=1e-3 if config.floatX == "float32" else 1e-8,
657+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
658+
)

0 commit comments

Comments
 (0)