You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
640
-
This function deals with orthonormal matrix arising from pt.linalg.svd decomposition (U, Vh) or arising from pt.linalg.qr
641
-
642
-
Parameters
643
-
----------
644
-
fgraph: FunctionGraph
645
-
Function graph being optimized
646
-
node: Apply
647
-
Node of the function graph to be optimized
648
-
649
-
Returns
650
-
-------
651
-
list of Variable, optional
652
-
List of optimized variables, or None if no optimization was performed
653
-
"""
654
-
# Dealing with orthonormal matrix from SVD
655
-
# Check if input to Inverse is coming from SVD
656
-
input_to_inv=node.inputs[0]
657
-
# Check if this input is coming from SVD with compute_uv = True
658
-
ifnot (
659
-
input_to_inv.owner
660
-
andisinstance(input_to_inv.owner.op, Blockwise)
661
-
andisinstance(input_to_inv.owner.op.core_op, SVD)
662
-
andinput_to_inv.owner.op.core_op.compute_uvisTrue
663
-
):
664
-
returnNone
665
-
666
-
# To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
667
-
ifinput_to_inv==input_to_inv.owner.outputs[1]:
668
-
returnNone
669
-
670
-
return [input_to_inv.T]
668
+
# @register_canonicalize
669
+
# @register_stabilize
670
+
# @node_rewriter([Blockwise])
671
+
# def rewrite_inv_for_orthonormal(fgraph, node):
672
+
# """
673
+
# This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
674
+
# This function deals with orthonormal matrix arising from pt.linalg.svd decomposition (U, Vh) or arising from pt.linalg.qr
675
+
676
+
# Parameters
677
+
# ----------
678
+
# fgraph: FunctionGraph
679
+
# Function graph being optimized
680
+
# node: Apply
681
+
# Node of the function graph to be optimized
682
+
683
+
# Returns
684
+
# -------
685
+
# list of Variable, optional
686
+
# List of optimized variables, or None if no optimization was performed
687
+
# """
688
+
# # Dealing with orthonormal matrix from SVD
689
+
# # Check if input to Inverse is coming from SVD
690
+
# input_to_inv = node.inputs[0]
691
+
# # Check if this input is coming from SVD with compute_uv = True
692
+
# if not (
693
+
# input_to_inv.owner
694
+
# and isinstance(input_to_inv.owner.op, Blockwise)
695
+
# and isinstance(input_to_inv.owner.op.core_op, SVD)
696
+
# and input_to_inv.owner.op.core_op.compute_uv is True
697
+
# ):
698
+
# return None
699
+
700
+
# # To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
701
+
# if input_to_inv == input_to_inv.owner.outputs[1]:
0 commit comments