Skip to content

Commit e328e79

Browse files
committed
simplifed logic for inv check
1 parent 5ee276a commit e328e79

File tree

2 files changed

+27
-32
lines changed

2 files changed

+27
-32
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import cast
44

55
from pytensor import Variable
6-
from pytensor.graph import Apply, FunctionGraph
6+
from pytensor.graph import Apply, Constant, FunctionGraph
77
from pytensor.graph.rewriting.basic import (
88
copy_stack_trace,
99
node_rewriter,
@@ -585,13 +585,16 @@ def _find_solve_with_eye(node) -> bool:
585585
return False
586586
# If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
587587
solve_inputs = node.inputs
588-
eye_input = solve_inputs[1].owner
588+
eye_node = solve_inputs[1].owner
589589

590590
# We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
591-
if not (eye_input and isinstance(eye_input.op, Eye)):
591+
if not (eye_node and isinstance(eye_node.op, Eye)):
592592
return False
593593

594-
if eye_input.inputs[-1].data.item() != 0:
594+
if (
595+
isinstance(eye_node.inputs[-1], Constant)
596+
and eye_node.inputs[-1].data.item() != 0
597+
):
595598
return False
596599
return True
597600

@@ -623,37 +626,35 @@ def rewrite_inv_inv(fgraph, node):
623626
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
624627
# If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
625628
# If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
626-
inv_check = False
627-
if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_inverses):
628-
inv_check = True
629-
if isinstance(node.op.core_op, valid_solves):
630-
inv_check = _find_solve_with_eye(node)
629+
if not isinstance(node.op.core_op, valid_inverses):
630+
return None
631631

632-
if not inv_check:
632+
if isinstance(node.op.core_op, valid_solves) and not _find_solve_with_eye(node):
633633
return None
634634

635635
potential_inner_inv = node.inputs[0].owner
636636
if potential_inner_inv is None or potential_inner_inv.op is None:
637637
return None
638638

639-
# Similar to the check for outer operation, we now run the same checks for the inner op.
640-
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
641-
inv_check_inner = False
642-
if isinstance(potential_inner_inv.op, Blockwise) and isinstance(
643-
potential_inner_inv.op.core_op, valid_inverses
644-
):
645-
inv_check_inner = True
646-
if isinstance(potential_inner_inv.op.core_op, valid_solves):
647-
inv_check_inner = _find_solve_with_eye(potential_inner_inv)
648-
649-
if not inv_check_inner:
650-
return None
651-
639+
# Check if inner op is blockwise and and possible inv
652640
if not (
653641
potential_inner_inv
654642
and isinstance(potential_inner_inv.op, Blockwise)
655643
and isinstance(node.op.core_op, valid_inverses)
656644
):
657645
return None
658646

647+
# Similar to the check for outer operation, we now run the same checks for the inner op.
648+
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
649+
if not (
650+
isinstance(potential_inner_inv.op, Blockwise)
651+
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
652+
):
653+
return None
654+
655+
if isinstance(
656+
potential_inner_inv.op.core_op, valid_solves
657+
) and not _find_solve_with_eye(potential_inner_inv):
658+
return None
659+
659660
return [potential_inner_inv.inputs[0]]

tests/tensor/rewriting/test_linalg.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
1212
from pytensor.configdefaults import config
13+
from pytensor.graph.rewriting.utils import rewrite_graph
1314
from pytensor.tensor import swapaxes
1415
from pytensor.tensor.blockwise import Blockwise
1516
from pytensor.tensor.elemwise import DimShuffle
@@ -567,12 +568,5 @@ def get_pt_function(x, op_name):
567568
x = pt.matrix("x")
568569
op1 = get_pt_function(x, inv_op_1)
569570
op2 = get_pt_function(op1, inv_op_2)
570-
f_rewritten = function([x], op2, mode="FAST_RUN")
571-
print(f_rewritten.dprint())
572-
nodes = f_rewritten.maker.fgraph.apply_nodes
573-
574-
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
575-
576-
assert not any(isinstance(node.op, valid_inverses) for node in nodes)
577-
x_testing = np.random.rand(10, 10).astype(config.floatX)
578-
np.testing.assert_allclose(f_rewritten(x_testing), x_testing)
571+
rewritten_out = rewrite_graph(op2)
572+
assert rewritten_out == x

0 commit comments

Comments
 (0)