Skip to content

Commit 4e88e29

Browse files
tanish1729jessegrabowski
authored andcommitted
fixed merge conflicts
1 parent 3e98b9f commit 4e88e29

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,64 @@ def rewrite_slogdet_kronecker(fgraph, node):
887887
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
888888

889889
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
890+
891+
892+
@register_canonicalize
893+
@register_stabilize
894+
@node_rewriter([Blockwise])
895+
def rewrite_cholesky_eye_to_eye(fgraph, node):
896+
"""
897+
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
898+
899+
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
900+
901+
Parameters
902+
----------
903+
fgraph: FunctionGraph
904+
Function graph being optimized
905+
node: Apply
906+
Node of the function graph to be optimized
907+
908+
Returns
909+
-------
910+
list of Variable, optional
911+
List of optimized variables, or None if no optimization was performed
912+
"""
913+
# Find whether cholesky op is being applied
914+
if not isinstance(node.op.core_op, Cholesky):
915+
return None
916+
917+
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
918+
eye_check = node.inputs[0]
919+
if not (
920+
eye_check.owner
921+
and isinstance(eye_check.owner.op, Eye)
922+
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
923+
):
924+
return None
925+
return [eye_check]
926+
927+
928+
@register_canonicalize
929+
@register_stabilize
930+
@node_rewriter([Blockwise])
931+
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
932+
# Find whether cholesky op is being applied
933+
if not isinstance(node.op.core_op, Cholesky):
934+
return None
935+
936+
# Check whether input is diagonal from multiplcation of identity matrix with a tensor
937+
inputs = node.inputs[0]
938+
inputs_or_none = _find_diag_from_eye_mul(inputs)
939+
if inputs_or_none is None:
940+
return None
941+
942+
eye_input, non_eye_inputs = inputs_or_none
943+
944+
# Dealing with only one other input
945+
if len(non_eye_inputs) != 1:
946+
return None
947+
948+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
949+
950+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,3 +803,66 @@ def test_slogdet_kronecker_rewrite():
803803
atol=1e-3 if config.floatX == "float32" else 1e-8,
804804
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805805
)
806+
807+
808+
def test_cholesky_eye_rewrite():
809+
x = pt.eye(10)
810+
x_mat = pt.matrix("x")
811+
L = pt.linalg.cholesky(x)
812+
L_mat = pt.linalg.cholesky(x_mat)
813+
f_rewritten = function([], L, mode="FAST_RUN")
814+
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
815+
nodes = f_rewritten.maker.fgraph.apply_nodes
816+
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
817+
818+
# Rewrite Test
819+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
820+
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
821+
822+
# Value Test
823+
x_test = np.eye(10)
824+
L = np.linalg.cholesky(x_test)
825+
rewritten_val = f_rewritten()
826+
827+
assert_allclose(
828+
L,
829+
rewritten_val,
830+
atol=1e-3 if config.floatX == "float32" else 1e-8,
831+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
832+
)
833+
834+
835+
@pytest.mark.parametrize(
836+
"shape",
837+
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
838+
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
839+
)
840+
def test_cholesky_diag_from_eye_mul(shape):
841+
# Initializing x based on scalar/vector/matrix
842+
x = pt.tensor("x", shape=shape)
843+
y = pt.eye(7) * x
844+
# Performing cholesky decomposition using pt.linalg.cholesky
845+
z_cholesky = pt.linalg.cholesky(y)
846+
847+
# REWRITE TEST
848+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
849+
nodes = f_rewritten.maker.fgraph.apply_nodes
850+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
851+
852+
# NUMERIC VALUE TEST
853+
if len(shape) == 0:
854+
x_test = np.array(np.random.rand()).astype(config.floatX)
855+
elif len(shape) == 1:
856+
x_test = np.random.rand(*shape).astype(config.floatX)
857+
else:
858+
x_test = np.random.rand(*shape).astype(config.floatX)
859+
x_test_matrix = np.eye(7) * x_test
860+
cholesky_val = np.linalg.cholesky(x_test_matrix)
861+
rewritten_val = f_rewritten(x_test)
862+
863+
assert_allclose(
864+
cholesky_val,
865+
rewritten_val,
866+
atol=1e-3 if config.floatX == "float32" else 1e-8,
867+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
868+
)

0 commit comments

Comments
 (0)