Skip to content

Commit 317accf

Browse files
committed
Added rewrite for inv(eye) and removed (for now) orthonormal rewrites
1 parent 5fad484 commit 317accf

File tree

2 files changed

+94
-37
lines changed

2 files changed

+94
-37
lines changed

pytensor/tensor/rewriting/linalg.py

+71-37
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,40 @@ def svd_uv_merge(fgraph, node):
572572
return [cl.outputs[1]]
573573

574574

575+
@register_canonicalize
576+
@register_stabilize
577+
@node_rewriter([Blockwise])
578+
def rewrite_inv_eye_to_eye(fgraph, node):
579+
"""
580+
This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself
581+
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op.
582+
Parameters
583+
----------
584+
fgraph: FunctionGraph
585+
Function graph being optimized
586+
node: Apply
587+
Node of the function graph to be optimized
588+
Returns
589+
-------
590+
list of Variable, optional
591+
List of optimized variables, or None if no optimization was performed
592+
"""
593+
valid_inverses = (MatrixInverse, MatrixPinv)
594+
core_op = node.op.core_op
595+
if not (isinstance(core_op, valid_inverses)):
596+
return None
597+
598+
# Check whether input to inverse is Eye and the 1's are on main diagonal
599+
eye_check = node.inputs[0]
600+
if not (
601+
eye_check.owner
602+
and isinstance(eye_check.owner.op, Eye)
603+
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
604+
):
605+
return None
606+
return [eye_check]
607+
608+
575609
@register_canonicalize
576610
@register_stabilize
577611
@node_rewriter([Blockwise])
@@ -631,40 +665,40 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
631665
return [eye_input / non_eye_input]
632666

633667

634-
@register_canonicalize
635-
@register_stabilize
636-
@node_rewriter([Blockwise])
637-
def rewrite_inv_for_orthonormal(fgraph, node):
638-
"""
639-
This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
640-
This function deals with orthonormal matrix arising from pt.linalg.svd decomposition (U, Vh) or arising from pt.linalg.qr
641-
642-
Parameters
643-
----------
644-
fgraph: FunctionGraph
645-
Function graph being optimized
646-
node: Apply
647-
Node of the function graph to be optimized
648-
649-
Returns
650-
-------
651-
list of Variable, optional
652-
List of optimized variables, or None if no optimization was performed
653-
"""
654-
# Dealing with orthonormal matrix from SVD
655-
# Check if input to Inverse is coming from SVD
656-
input_to_inv = node.inputs[0]
657-
# Check if this input is coming from SVD with compute_uv = True
658-
if not (
659-
input_to_inv.owner
660-
and isinstance(input_to_inv.owner.op, Blockwise)
661-
and isinstance(input_to_inv.owner.op.core_op, SVD)
662-
and input_to_inv.owner.op.core_op.compute_uv is True
663-
):
664-
return None
665-
666-
# To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
667-
if input_to_inv == input_to_inv.owner.outputs[1]:
668-
return None
669-
670-
return [input_to_inv.T]
668+
# @register_canonicalize
669+
# @register_stabilize
670+
# @node_rewriter([Blockwise])
671+
# def rewrite_inv_for_orthonormal(fgraph, node):
672+
# """
673+
# This rewrite takes advantage of the fact that for an orthonormal matrix, the inverse is simply the transpose.
674+
# This function deals with orthonormal matrix arising from pt.linalg.svd decomposition (U, Vh) or arising from pt.linalg.qr
675+
676+
# Parameters
677+
# ----------
678+
# fgraph: FunctionGraph
679+
# Function graph being optimized
680+
# node: Apply
681+
# Node of the function graph to be optimized
682+
683+
# Returns
684+
# -------
685+
# list of Variable, optional
686+
# List of optimized variables, or None if no optimization was performed
687+
# """
688+
# # Dealing with orthonormal matrix from SVD
689+
# # Check if input to Inverse is coming from SVD
690+
# input_to_inv = node.inputs[0]
691+
# # Check if this input is coming from SVD with compute_uv = True
692+
# if not (
693+
# input_to_inv.owner
694+
# and isinstance(input_to_inv.owner.op, Blockwise)
695+
# and isinstance(input_to_inv.owner.op.core_op, SVD)
696+
# and input_to_inv.owner.op.core_op.compute_uv is True
697+
# ):
698+
# return None
699+
700+
# # To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
701+
# if input_to_inv == input_to_inv.owner.outputs[1]:
702+
# return None
703+
704+
# return [input_to_inv.T]

tests/tensor/rewriting/test_linalg.py

+23
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,29 @@ def test_svd_uv_merge():
559559
assert svd_counter == 1
560560

561561

562+
def test_inv_eye_to_eye():
563+
x = pt.eye(10)
564+
x_inv = pt.linalg.inv(x)
565+
f_rewritten = function([], x_inv, mode="FAST_RUN")
566+
nodes = f_rewritten.maker.fgraph.apply_nodes
567+
568+
# Rewrite Test
569+
valid_inverses = (MatrixInverse, MatrixPinv)
570+
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
571+
572+
# Value Test
573+
x_test = np.eye(10)
574+
x_inv_val = np.linalg.inv(x_test)
575+
rewritten_val = f_rewritten()
576+
577+
assert_allclose(
578+
x_inv_val,
579+
rewritten_val,
580+
atol=1e-3 if config.floatX == "float32" else 1e-8,
581+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
582+
)
583+
584+
562585
@pytest.mark.parametrize(
563586
"shape",
564587
[(), (7,), (7, 7)],

0 commit comments

Comments
 (0)