Skip to content

Commit d5d48fc

Browse files
committed
removed orthonormal rewrites
1 parent 317accf commit d5d48fc

File tree

2 files changed

+0
-66
lines changed

2 files changed

+0
-66
lines changed

pytensor/tensor/rewriting/linalg.py

-39
Original file line numberDiff line numberDiff line change
@@ -663,42 +663,3 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
663663
else:
664664
# For Vector or Scalar
665665
return [eye_input / non_eye_input]
666-
667-
668-
# @register_canonicalize
669-
# @register_stabilize
670-
# @node_rewriter([Blockwise])
671-
# def rewrite_inv_for_orthonormal(fgraph, node):
672-
# """
673-
# This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
674-
# This function deals with orthonormal matrix arising from pt.linalg.svd decomposition (U, Vh) or arising from pt.linalg.qr
675-
676-
# Parameters
677-
# ----------
678-
# fgraph: FunctionGraph
679-
# Function graph being optimized
680-
# node: Apply
681-
# Node of the function graph to be optimized
682-
683-
# Returns
684-
# -------
685-
# list of Variable, optional
686-
# List of optimized variables, or None if no optimization was performed
687-
# """
688-
# # Dealing with orthonormal matrix from SVD
689-
# # Check if input to Inverse is coming from SVD
690-
# input_to_inv = node.inputs[0]
691-
# # Check if this input is coming from SVD with compute_uv = True
692-
# if not (
693-
# input_to_inv.owner
694-
# and isinstance(input_to_inv.owner.op, Blockwise)
695-
# and isinstance(input_to_inv.owner.op.core_op, SVD)
696-
# and input_to_inv.owner.op.core_op.compute_uv is True
697-
# ):
698-
# return None
699-
700-
# # To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
701-
# if input_to_inv == input_to_inv.owner.outputs[1]:
702-
# return None
703-
704-
# return [input_to_inv.T]

tests/tensor/rewriting/test_linalg.py

-27
Original file line numberDiff line numberDiff line change
@@ -644,30 +644,3 @@ def test_inv_diag_from_diag():
644644
atol=ATOL,
645645
rtol=RTOL,
646646
)
647-
648-
649-
def test_inv_orthonormal():
650-
x = pt.dmatrix("x")
651-
u, s, vh = pt.linalg.svd(x)
652-
# Calculating inverse using pt.linalg.inv
653-
u_inv = pt.linalg.inv(u)
654-
655-
# REWRITE TEST
656-
f_rewritten = function([x], u_inv, mode="FAST_RUN")
657-
nodes = f_rewritten.maker.fgraph.apply_nodes
658-
659-
valid_inverses = (MatrixInverse, MatrixPinv)
660-
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
661-
662-
# NUMERIC VALUE TEST
663-
x_test = np.random.rand(7, 7).astype(config.floatX)
664-
u_test, _, _ = np.linalg.svd(x_test)
665-
inverse_matrix = np.linalg.inv(u_test)
666-
rewritten_inverse = f_rewritten(x_test)
667-
668-
assert_allclose(
669-
inverse_matrix,
670-
rewritten_inverse,
671-
atol=ATOL,
672-
rtol=RTOL,
673-
)

0 commit comments

Comments
 (0)