Skip to content

Commit 1d62dd3

Browse files
committed
removed rewrite for solve with eye
1 parent e328e79 commit 1d62dd3

File tree

2 files changed

+7
-55
lines changed

2 files changed

+7
-55
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import cast
44

55
from pytensor import Variable
6-
from pytensor.graph import Apply, Constant, FunctionGraph
6+
from pytensor.graph import Apply, FunctionGraph
77
from pytensor.graph.rewriting.basic import (
88
copy_stack_trace,
99
node_rewriter,
@@ -40,7 +40,6 @@
4040
Cholesky,
4141
Solve,
4242
SolveBase,
43-
SolveTriangular,
4443
block_diag,
4544
cholesky,
4645
solve,
@@ -572,41 +571,14 @@ def svd_uv_merge(fgraph, node):
572571
return [cl.outputs[1]]
573572

574573

575-
def _find_solve_with_eye(node) -> bool:
576-
"""
577-
The result of solve(A, b) is the solution x to the linear equation Ax = b. If b is an identity matrix (Eye), x is simply inv(A).
578-
Here, we are just recognising whether the solve operation returns an inverse or not; not replacing it because solve is mathematically more stable than inv.
579-
"""
580-
valid_solves = (Solve, SolveTriangular)
581-
# First, we verify whether we have a valid solve op
582-
if not (
583-
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_solves)
584-
):
585-
return False
586-
# If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
587-
solve_inputs = node.inputs
588-
eye_node = solve_inputs[1].owner
589-
590-
# We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
591-
if not (eye_node and isinstance(eye_node.op, Eye)):
592-
return False
593-
594-
if (
595-
isinstance(eye_node.inputs[-1], Constant)
596-
and eye_node.inputs[-1].data.item() != 0
597-
):
598-
return False
599-
return True
600-
601-
602574
@register_canonicalize
603575
@register_stabilize
604576
@node_rewriter([Blockwise])
605577
def rewrite_inv_inv(fgraph, node):
606578
"""
607579
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
608580
609-
Here, we check for direct inverse operations (inv/pinv) and also solve operations (solve/solve_triangular) in the case when b = Eye. This allows any combination of these "inverse" nodes to be simply rewritten.
581+
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
610582
611583
Parameters
612584
----------
@@ -620,18 +592,13 @@ def rewrite_inv_inv(fgraph, node):
620592
list of Variable, optional
621593
List of optimized variables, or None if no optimization was performed
622594
"""
623-
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
624-
valid_solves = (Solve, SolveTriangular)
625-
# Check if its a valid inverse operation (either inv/pinv or if its solve, then b = eye)
595+
valid_inverses = (MatrixInverse, MatrixPinv)
596+
# Check if its a valid inverse operation (either inv/pinv)
626597
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
627-
# If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
628-
# If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
598+
# If the outer operation is not a valid inverse, we do not apply this rewrite
629599
if not isinstance(node.op.core_op, valid_inverses):
630600
return None
631601

632-
if isinstance(node.op.core_op, valid_solves) and not _find_solve_with_eye(node):
633-
return None
634-
635602
potential_inner_inv = node.inputs[0].owner
636603
if potential_inner_inv is None or potential_inner_inv.op is None:
637604
return None
@@ -644,17 +611,4 @@ def rewrite_inv_inv(fgraph, node):
644611
):
645612
return None
646613

647-
# Similar to the check for outer operation, we now run the same checks for the inner op.
648-
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
649-
if not (
650-
isinstance(potential_inner_inv.op, Blockwise)
651-
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
652-
):
653-
return None
654-
655-
if isinstance(
656-
potential_inner_inv.op.core_op, valid_solves
657-
) and not _find_solve_with_eye(potential_inner_inv):
658-
return None
659-
660614
return [potential_inner_inv.inputs[0]]

tests/tensor/rewriting/test_linalg.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,10 @@ def test_svd_uv_merge():
557557
assert svd_counter == 1
558558

559559

560-
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv", "solve", "solve_triangular"])
561-
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv", "solve", "solve_triangular"])
560+
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
561+
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
562562
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
563563
def get_pt_function(x, op_name):
564-
if "solve" in op_name:
565-
return getattr(pt.linalg, op_name)(x, pt.eye(x.shape[0]))
566564
return getattr(pt.linalg, op_name)(x)
567565

568566
x = pt.matrix("x")

0 commit comments

Comments
 (0)