Skip to content

Commit e6b6476

Browse files
committed
changed check for inverse and solve
1 parent dde7eb9 commit e6b6476

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -595,30 +595,31 @@ def _find_solve_with_eye(node):
595595
@register_stabilize
596596
@node_rewriter([Blockwise])
597597
def rewrite_inv_inv(fgraph, node):
598-
print(f"NODE - {node}")
599598
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
600599
valid_solves = (Solve, SolveTriangular)
601600
# Check if Solve has b = eye
602-
solve_inv_check = False
603-
if hasattr(node.op, "core_op") and isinstance(node.op.core_op, valid_solves):
604-
solve_inv_check = _find_solve_with_eye(node)
601+
inv_check = False
602+
if hasattr(node.op, "core_op") and isinstance(node.op.core_op, valid_inverses):
603+
inv_check = True
604+
if isinstance(node.op.core_op, valid_solves):
605+
inv_check = _find_solve_with_eye(node)
605606

606-
if not solve_inv_check:
607-
return None
608-
609-
if not (isinstance(node.op.core_op, valid_inverses)):
607+
if not inv_check:
610608
return None
611609

612610
potential_inner_inv = node.inputs[0].owner
613611
if potential_inner_inv is None or potential_inner_inv.op is None:
614612
return None
615613
# Check if its an inner solve as well, does that have b = eye
616-
solve_inv_check = False
614+
inv_check = False
617615
if hasattr(potential_inner_inv.op, "core_op") and isinstance(
618-
potential_inner_inv.op.core_op, valid_solves
616+
potential_inner_inv.op.core_op, valid_inverses
619617
):
620-
solve_inv_check = _find_solve_with_eye(potential_inner_inv)
621-
if not solve_inv_check:
618+
inv_check = True
619+
if isinstance(potential_inner_inv.op.core_op, valid_solves):
620+
inv_check = _find_solve_with_eye(potential_inner_inv)
621+
622+
if not inv_check:
622623
return None
623624

624625
if not (

tests/tensor/rewriting/test_linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ def get_pt_function(x, op_name):
570570
op1 = get_pt_function(x, inv_op_1)
571571
op2 = get_pt_function(op1, inv_op_2)
572572
f_rewritten = function([x], op2, mode="FAST_RUN")
573+
print(f_rewritten.dprint())
573574
nodes = f_rewritten.maker.fgraph.apply_nodes
574575

575576
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)

0 commit comments

Comments
 (0)