|
3 | 3 | from typing import cast
|
4 | 4 |
|
5 | 5 | from pytensor import Variable
|
| 6 | +from pytensor import tensor as pt |
6 | 7 | from pytensor.graph import Apply, FunctionGraph
|
7 | 8 | from pytensor.graph.rewriting.basic import (
|
8 | 9 | copy_stack_trace,
|
@@ -611,3 +612,96 @@ def rewrite_inv_inv(fgraph, node):
|
611 | 612 | ):
|
612 | 613 | return None
|
613 | 614 | return [potential_inner_inv.inputs[0]]
|
| 615 | + |
| 616 | + |
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_eye_to_eye(fgraph, node):
    """
    Replace the inverse of an identity matrix with the identity matrix itself.

    The inverse of an identity matrix is the matrix itself. An identity matrix is
    recognized as the output of an ``Eye`` Op whose offset ``k`` (its last input)
    is known at rewrite time and equal to 0.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    valid_inverses = (MatrixInverse, MatrixPinv)
    core_op = node.op.core_op
    if not isinstance(core_op, valid_inverses):
        return None

    # The input of the inverse must be produced directly by an Eye Op
    eye_check = node.inputs[0]
    if not (eye_check.owner and isinstance(eye_check.owner.op, Eye)):
        return None

    # The offset can only be inspected when the variable carries a `data`
    # attribute. Use None as the sentinel: a plain int fallback has no
    # `.item()` method and would raise AttributeError for a symbolic offset.
    k_data = getattr(eye_check.owner.inputs[-1], "data", None)
    if k_data is None or k_data.item() != 0:
        return None

    # NOTE(review): this does not verify that the Eye is square; confirm that
    # returning the input unchanged is acceptable for MatrixPinv on a
    # rectangular Eye.
    return [eye_check]
| 649 | + |
| 650 | + |
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
    """
    Replace the inverse of a diagonal matrix with the reciprocal of its diagonal.

    For a diagonal matrix, the inverse is a diagonal matrix whose diagonal entries
    are the reciprocals of the original diagonal entries. Two ways of building a
    diagonal matrix are recognized:

    * ``pt.diag`` applied to a vector (an ``AllocDiag`` Op with zero offset);
    * the elementwise multiplication of an eye with a scalar/vector/matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    valid_inverses = (MatrixInverse, MatrixPinv)
    core_op = node.op.core_op
    if not isinstance(core_op, valid_inverses):
        return None

    inputs = node.inputs[0]
    # Check for use of pt.diag first
    if (
        inputs.owner
        and isinstance(inputs.owner.op, AllocDiag)
        and AllocDiag.is_offset_zero(inputs.owner)
    ):
        inv_input = inputs.owner.inputs[0]
        # Only the plain vector case is handled here; anything else falls
        # through to the eye-multiplication check below.
        if inv_input.type.ndim == 1:
            inv_val = pt.diag(1 / inv_input)
            return [inv_val]

    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
    inputs_or_none = _find_diag_from_eye_mul(inputs)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none

    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

    non_eye_input = non_eye_inputs[0]

    # For a matrix, only the diagonal entries survive the multiplication with
    # eye, so extract them first. The extracted diagonal has shape (..., n);
    # insert a broadcastable axis at position -2 so it lines up with the
    # column axis of eye even when leading batch dimensions are present.
    if non_eye_input.type.broadcastable[-2:] == (False, False):
        diagonal = non_eye_input.diagonal(axis1=-1, axis2=-2)
        non_eye_input = diagonal[..., None, :]

    # Vectors and scalars already broadcast correctly against eye
    return [eye_input / non_eye_input]
0 commit comments