|
32 | 32 | inv,
|
33 | 33 | kron,
|
34 | 34 | pinv,
|
| 35 | + slogdet, |
35 | 36 | svd,
|
36 | 37 | )
|
37 | 38 | from pytensor.tensor.rewriting.basic import (
|
@@ -710,39 +711,110 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
|
710 | 711 | @register_stabilize
|
711 | 712 | @node_rewriter([ExtractDiag])
|
712 | 713 | def rewrite_diag_blockdiag(fgraph, node):
|
| 714 | + """ |
| 715 | + This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices. |
| 716 | +
|
| 717 | + diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...) |
| 718 | +
|
| 719 | + Parameters |
| 720 | + ---------- |
| 721 | + fgraph: FunctionGraph |
| 722 | + Function graph being optimized |
| 723 | + node: Apply |
| 724 | + Node of the function graph to be optimized |
| 725 | +
|
| 726 | + Returns |
| 727 | + ------- |
| 728 | + list of Variable, optional |
| 729 | + List of optimized variables, or None if no optimization was performed |
| 730 | + """ |
713 | 731 | # Check for inner block_diag operation
|
714 |
| - potential_blockdiag = node.inputs[0].owner |
| 732 | + potential_block_diag = node.inputs[0].owner |
715 | 733 | if not (
|
716 |
| - potential_blockdiag |
717 |
| - and isinstance(potential_blockdiag.op, Blockwise) |
718 |
| - and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) |
| 734 | + potential_block_diag |
| 735 | + and isinstance(potential_block_diag.op, Blockwise) |
| 736 | + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
719 | 737 | ):
|
720 | 738 | return None
|
721 | 739 |
|
722 | 740 | # Find the composing sub_matrices
|
723 |
| - submatrices = potential_blockdiag.inputs |
| 741 | + submatrices = potential_block_diag.inputs |
724 | 742 | submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]
|
725 |
| - output = [concatenate(submatrices_diag)] |
726 | 743 |
|
727 |
| - return output |
| 744 | + return [concatenate(submatrices_diag)] |
728 | 745 |
|
729 | 746 |
|
730 | 747 | @register_canonicalize
|
731 | 748 | @register_stabilize
|
732 | 749 | @node_rewriter([det])
|
733 | 750 | def rewrite_det_blockdiag(fgraph, node):
|
| 751 | + """ |
| 752 | + This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values. |
| 753 | +
|
| 754 | + det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...) |
| 755 | +
|
| 756 | + Parameters |
| 757 | + ---------- |
| 758 | + fgraph: FunctionGraph |
| 759 | + Function graph being optimized |
| 760 | + node: Apply |
| 761 | + Node of the function graph to be optimized |
| 762 | +
|
| 763 | + Returns |
| 764 | + ------- |
| 765 | + list of Variable, optional |
| 766 | + List of optimized variables, or None if no optimization was performed |
| 767 | + """ |
734 | 768 | # Check for inner block_diag operation
|
735 |
| - potential_blockdiag = node.inputs[0].owner |
| 769 | + potential_block_diag = node.inputs[0].owner |
736 | 770 | if not (
|
737 |
| - potential_blockdiag |
738 |
| - and isinstance(potential_blockdiag.op, Blockwise) |
739 |
| - and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) |
| 771 | + potential_block_diag |
| 772 | + and isinstance(potential_block_diag.op, Blockwise) |
| 773 | + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
740 | 774 | ):
|
741 | 775 | return None
|
742 | 776 |
|
743 | 777 | # Find the composing sub_matrices
|
744 |
| - sub_matrices = potential_blockdiag.inputs |
| 778 | + sub_matrices = potential_block_diag.inputs |
745 | 779 | det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]
|
746 |
| - prod_det_sub_matrices = prod(det_sub_matrices) |
747 | 780 |
|
748 |
| - return [prod_det_sub_matrices] |
| 781 | + return [prod(det_sub_matrices)] |
| 782 | + |
| 783 | + |
| 784 | +@register_canonicalize |
| 785 | +@register_stabilize |
| 786 | +@node_rewriter([slogdet]) |
| 787 | +def rewrite_slogdet_blockdiag(fgraph, node): |
| 788 | + """ |
| 789 | + This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 790 | +
|
| 791 | + slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) |
| 792 | +
|
| 793 | + Parameters |
| 794 | + ---------- |
| 795 | + fgraph: FunctionGraph |
| 796 | + Function graph being optimized |
| 797 | + node: Apply |
| 798 | + Node of the function graph to be optimized |
| 799 | +
|
| 800 | + Returns |
| 801 | + ------- |
| 802 | + list of Variable, optional |
| 803 | + List of optimized variables, or None if no optimization was performed |
| 804 | + """ |
| 805 | + # Check for inner block_diag operation |
| 806 | + potential_block_diag = node.inputs[0].owner |
| 807 | + if not ( |
| 808 | + potential_block_diag |
| 809 | + and isinstance(potential_block_diag.op, Blockwise) |
| 810 | + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
| 811 | + ): |
| 812 | + return None |
| 813 | + |
| 814 | + # Find the composing sub_matrices |
| 815 | + sub_matrices = potential_block_diag.inputs |
| 816 | + sign_sub_matrices, logdet_sub_matrices = zip( |
| 817 | + *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] |
| 818 | + ) |
| 819 | + |
| 820 | + return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] |
0 commit comments