Skip to content

Commit 7074e50

Browse files
tanish1729 authored and jessegrabowski committed
fixed failing tests and added rewrite for pt.diag
1 parent 81ac247 commit 7074e50

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Callable
33
from typing import cast
44

5+
import pytensor.tensor as pt
56
from pytensor import Variable
67
from pytensor import tensor as pt
78
from pytensor.graph import Apply, FunctionGraph
@@ -859,13 +860,24 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
859860
@register_canonicalize
860861
@register_stabilize
861862
@node_rewriter([Blockwise])
862-
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
863+
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
863864
# Find whether cholesky op is being applied
864865
if not isinstance(node.op.core_op, Cholesky):
865866
return None
866867

867-
# Check whether input is diagonal from multiplication of identity matrix with a tensor
868868
inputs = node.inputs[0]
869+
# Check for use of pt.diag first
870+
if (
871+
inputs.owner
872+
and isinstance(inputs.owner.op, AllocDiag)
873+
and AllocDiag.is_offset_zero(inputs.owner)
874+
):
875+
cholesky_input = inputs.owner.inputs[0]
876+
if cholesky_input.type.ndim == 1:
877+
cholesky_val = pt.diag(cholesky_input**0.5)
878+
return [cholesky_val]
879+
880+
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
869881
inputs_or_none = _find_diag_from_eye_mul(inputs)
870882
if inputs_or_none is None:
871883
return None
@@ -876,6 +888,13 @@ def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
876888
if len(non_eye_inputs) != 1:
877889
return None
878890

879-
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
891+
non_eye_input = non_eye_inputs[0]
880892

881-
return [eye_input * (non_eye_input**0.5)]
893+
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
894+
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
895+
if non_eye_input.type.broadcastable[-2:] == (False, False):
896+
# For Matrix
897+
return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)]
898+
else:
899+
# For Vector or Scalar
900+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -782,8 +782,8 @@ def test_cholesky_eye_rewrite():
782782

783783
@pytest.mark.parametrize(
784784
"shape",
785-
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
786-
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
785+
[(), (7,), (7, 7)],
786+
ids=["scalar", "vector", "matrix"],
787787
)
788788
def test_cholesky_diag_from_eye_mul(shape):
789789
# Initializing x based on scalar/vector/matrix
@@ -814,3 +814,28 @@ def test_cholesky_diag_from_eye_mul(shape):
814814
atol=1e-3 if config.floatX == "float32" else 1e-8,
815815
rtol=1e-3 if config.floatX == "float32" else 1e-8,
816816
)
817+
818+
819+
def test_cholesky_diag_from_diag():
820+
x = pt.dvector("x")
821+
x_diag = pt.diag(x)
822+
x_cholesky = pt.linalg.cholesky(x_diag)
823+
824+
# REWRITE TEST
825+
f_rewritten = function([x], x_cholesky, mode="FAST_RUN")
826+
nodes = f_rewritten.maker.fgraph.apply_nodes
827+
828+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
829+
830+
# NUMERIC VALUE TEST
831+
x_test = np.random.rand(10)
832+
x_test_matrix = np.eye(10) * x_test
833+
cholesky_val = np.linalg.cholesky(x_test_matrix)
834+
rewritten_cholesky = f_rewritten(x_test)
835+
836+
assert_allclose(
837+
cholesky_val,
838+
rewritten_cholesky,
839+
atol=1e-3 if config.floatX == "float32" else 1e-8,
840+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
841+
)

0 commit comments

Comments
 (0)