|
3 | 3 | from typing import cast
|
4 | 4 |
|
5 | 5 | from pytensor import Variable
|
6 |
| -from pytensor.graph import Apply, FunctionGraph |
| 6 | +from pytensor.graph import Apply, Constant, FunctionGraph |
7 | 7 | from pytensor.graph.rewriting.basic import (
|
8 | 8 | copy_stack_trace,
|
9 | 9 | node_rewriter,
|
@@ -585,13 +585,16 @@ def _find_solve_with_eye(node) -> bool:
|
585 | 585 | return False
|
586 | 586 | # If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
|
587 | 587 | solve_inputs = node.inputs
|
588 |
| - eye_input = solve_inputs[1].owner |
| 588 | + eye_node = solve_inputs[1].owner |
589 | 589 |
|
590 | 590 | # We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
|
591 |
| - if not (eye_input and isinstance(eye_input.op, Eye)): |
| 591 | + if not (eye_node and isinstance(eye_node.op, Eye)): |
592 | 592 | return False
|
593 | 593 |
|
594 |
| - if eye_input.inputs[-1].data.item() != 0: |
| 594 | + if ( |
| 595 | + isinstance(eye_node.inputs[-1], Constant) |
| 596 | + and eye_node.inputs[-1].data.item() != 0 |
| 597 | + ): |
595 | 598 | return False
|
596 | 599 | return True
|
597 | 600 |
|
@@ -623,37 +626,35 @@ def rewrite_inv_inv(fgraph, node):
|
623 | 626 | # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
|
624 | 627 | # If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
|
625 | 628 | # If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
|
626 |
| - inv_check = False |
627 |
| - if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_inverses): |
628 |
| - inv_check = True |
629 |
| - if isinstance(node.op.core_op, valid_solves): |
630 |
| - inv_check = _find_solve_with_eye(node) |
| 629 | + if not isinstance(node.op.core_op, valid_inverses): |
| 630 | + return None |
631 | 631 |
|
632 |
| - if not inv_check: |
| 632 | + if isinstance(node.op.core_op, valid_solves) and not _find_solve_with_eye(node): |
633 | 633 | return None
|
634 | 634 |
|
635 | 635 | potential_inner_inv = node.inputs[0].owner
|
636 | 636 | if potential_inner_inv is None or potential_inner_inv.op is None:
|
637 | 637 | return None
|
638 | 638 |
|
639 |
| - # Similar to the check for outer operation, we now run the same checks for the inner op. |
640 |
| - # If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None. |
641 |
| - inv_check_inner = False |
642 |
| - if isinstance(potential_inner_inv.op, Blockwise) and isinstance( |
643 |
| - potential_inner_inv.op.core_op, valid_inverses |
644 |
| - ): |
645 |
| - inv_check_inner = True |
646 |
| - if isinstance(potential_inner_inv.op.core_op, valid_solves): |
647 |
| - inv_check_inner = _find_solve_with_eye(potential_inner_inv) |
648 |
| - |
649 |
| - if not inv_check_inner: |
650 |
| - return None |
651 |
| - |
| 639 | + # Check if inner op is blockwise and and possible inv |
652 | 640 | if not (
|
653 | 641 | potential_inner_inv
|
654 | 642 | and isinstance(potential_inner_inv.op, Blockwise)
|
655 | 643 | and isinstance(node.op.core_op, valid_inverses)
|
656 | 644 | ):
|
657 | 645 | return None
|
658 | 646 |
|
| 647 | + # Similar to the check for outer operation, we now run the same checks for the inner op. |
| 648 | + # If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None. |
| 649 | + if not ( |
| 650 | + isinstance(potential_inner_inv.op, Blockwise) |
| 651 | + and isinstance(potential_inner_inv.op.core_op, valid_inverses) |
| 652 | + ): |
| 653 | + return None |
| 654 | + |
| 655 | + if isinstance( |
| 656 | + potential_inner_inv.op.core_op, valid_solves |
| 657 | + ) and not _find_solve_with_eye(potential_inner_inv): |
| 658 | + return None |
| 659 | + |
659 | 660 | return [potential_inner_inv.inputs[0]]
|
0 commit comments