@@ -595,30 +595,31 @@ def _find_solve_with_eye(node):
595
595
@register_stabilize
596
596
@node_rewriter ([Blockwise ])
597
597
def rewrite_inv_inv (fgraph , node ):
598
- print (f"NODE - { node } " )
599
598
valid_inverses = (MatrixInverse , MatrixPinv , Solve , SolveTriangular )
600
599
valid_solves = (Solve , SolveTriangular )
601
600
# Check if Solve has b = eye
602
- solve_inv_check = False
603
- if hasattr (node .op , "core_op" ) and isinstance (node .op .core_op , valid_solves ):
604
- solve_inv_check = _find_solve_with_eye (node )
601
+ inv_check = False
602
+ if hasattr (node .op , "core_op" ) and isinstance (node .op .core_op , valid_inverses ):
603
+ inv_check = True
604
+ if isinstance (node .op .core_op , valid_solves ):
605
+ inv_check = _find_solve_with_eye (node )
605
606
606
- if not solve_inv_check :
607
- return None
608
-
609
- if not (isinstance (node .op .core_op , valid_inverses )):
607
+ if not inv_check :
610
608
return None
611
609
612
610
potential_inner_inv = node .inputs [0 ].owner
613
611
if potential_inner_inv is None or potential_inner_inv .op is None :
614
612
return None
615
613
# Check if its an inner solve as well, does that have b = eye
616
- solve_inv_check = False
614
+ inv_check = False
617
615
if hasattr (potential_inner_inv .op , "core_op" ) and isinstance (
618
- potential_inner_inv .op .core_op , valid_solves
616
+ potential_inner_inv .op .core_op , valid_inverses
619
617
):
620
- solve_inv_check = _find_solve_with_eye (potential_inner_inv )
621
- if not solve_inv_check :
618
+ inv_check = True
619
+ if isinstance (potential_inner_inv .op .core_op , valid_solves ):
620
+ inv_check = _find_solve_with_eye (potential_inner_inv )
621
+
622
+ if not inv_check :
622
623
return None
623
624
624
625
if not (
0 commit comments