Skip to content

Commit 81ac247

Browse files
tanish1729jessegrabowski
authored andcommitted
fixed merge conflicts
1 parent 5632777 commit 81ac247

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,64 @@ def rewrite_slogdet_blockdiag(fgraph, node):
818818
)
819819

820820
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
821+
822+
823+
@register_canonicalize
824+
@register_stabilize
825+
@node_rewriter([Blockwise])
826+
def rewrite_cholesky_eye_to_eye(fgraph, node):
827+
"""
828+
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
829+
830+
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
831+
832+
Parameters
833+
----------
834+
fgraph: FunctionGraph
835+
Function graph being optimized
836+
node: Apply
837+
Node of the function graph to be optimized
838+
839+
Returns
840+
-------
841+
list of Variable, optional
842+
List of optimized variables, or None if no optimization was performed
843+
"""
844+
# Find whether cholesky op is being applied
845+
if not isinstance(node.op.core_op, Cholesky):
846+
return None
847+
848+
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
849+
eye_check = node.inputs[0]
850+
if not (
851+
eye_check.owner
852+
and isinstance(eye_check.owner.op, Eye)
853+
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
854+
):
855+
return None
856+
return [eye_check]
857+
858+
859+
@register_canonicalize
860+
@register_stabilize
861+
@node_rewriter([Blockwise])
862+
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
863+
# Find whether cholesky op is being applied
864+
if not isinstance(node.op.core_op, Cholesky):
865+
return None
866+
867+
# Check whether input is diagonal from multiplcation of identity matrix with a tensor
868+
inputs = node.inputs[0]
869+
inputs_or_none = _find_diag_from_eye_mul(inputs)
870+
if inputs_or_none is None:
871+
return None
872+
873+
eye_input, non_eye_inputs = inputs_or_none
874+
875+
# Dealing with only one other input
876+
if len(non_eye_inputs) != 1:
877+
return None
878+
879+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
880+
881+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,3 +751,66 @@ def test_slogdet_blockdiag_rewrite():
751751
atol=1e-3 if config.floatX == "float32" else 1e-8,
752752
rtol=1e-3 if config.floatX == "float32" else 1e-8,
753753
)
754+
755+
756+
def test_cholesky_eye_rewrite():
757+
x = pt.eye(10)
758+
x_mat = pt.matrix("x")
759+
L = pt.linalg.cholesky(x)
760+
L_mat = pt.linalg.cholesky(x_mat)
761+
f_rewritten = function([], L, mode="FAST_RUN")
762+
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
763+
nodes = f_rewritten.maker.fgraph.apply_nodes
764+
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
765+
766+
# Rewrite Test
767+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
768+
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
769+
770+
# Value Test
771+
x_test = np.eye(10)
772+
L = np.linalg.cholesky(x_test)
773+
rewritten_val = f_rewritten()
774+
775+
assert_allclose(
776+
L,
777+
rewritten_val,
778+
atol=1e-3 if config.floatX == "float32" else 1e-8,
779+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
780+
)
781+
782+
783+
@pytest.mark.parametrize(
784+
"shape",
785+
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
786+
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
787+
)
788+
def test_cholesky_diag_from_eye_mul(shape):
789+
# Initializing x based on scalar/vector/matrix
790+
x = pt.tensor("x", shape=shape)
791+
y = pt.eye(7) * x
792+
# Performing cholesky decomposition using pt.linalg.cholesky
793+
z_cholesky = pt.linalg.cholesky(y)
794+
795+
# REWRITE TEST
796+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
797+
nodes = f_rewritten.maker.fgraph.apply_nodes
798+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
799+
800+
# NUMERIC VALUE TEST
801+
if len(shape) == 0:
802+
x_test = np.array(np.random.rand()).astype(config.floatX)
803+
elif len(shape) == 1:
804+
x_test = np.random.rand(*shape).astype(config.floatX)
805+
else:
806+
x_test = np.random.rand(*shape).astype(config.floatX)
807+
x_test_matrix = np.eye(7) * x_test
808+
cholesky_val = np.linalg.cholesky(x_test_matrix)
809+
rewritten_val = f_rewritten(x_test)
810+
811+
assert_allclose(
812+
cholesky_val,
813+
rewritten_val,
814+
atol=1e-3 if config.floatX == "float32" else 1e-8,
815+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
816+
)

0 commit comments

Comments
 (0)