@@ -12,8 +12,11 @@
 from pytensor.scalar.basic import Mul
 from pytensor.tensor.basic import (
     AllocDiag,
+    ExtractDiag,
     Eye,
     TensorVariable,
+    concatenate,
+    diag,
     diagonal,
 )
 from pytensor.tensor.blas import Dot22
@@ -29,6 +32,7 @@
     inv,
     kron,
     pinv,
+    slogdet,
     svd,
 )
 from pytensor.tensor.rewriting.basic import (
@@ -701,3 +705,116 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
     non_eye_input = pt.shape_padaxis(non_eye_diag, -2)

     return [eye_input / non_eye_input]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([ExtractDiag])
+def rewrite_diag_blockdiag(fgraph, node):
+    """
+    This rewrite simplifies extracting the diagonal of a block-diagonal matrix by concatenating the diagonals of the individual sub-matrices.
+
+    diag(block_diag(a, b, c, ...)) = concat(diag(a), diag(b), diag(c), ...)
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check for an inner block_diag operation
+    potential_block_diag = node.inputs[0].owner
+    if not (
+        potential_block_diag
+        and isinstance(potential_block_diag.op, Blockwise)
+        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
+    ):
+        return None
+
+    # Find the composing sub-matrices and concatenate their diagonals
+    submatrices = potential_block_diag.inputs
+    submatrices_diag = [diag(m) for m in submatrices]
+
+    return [concatenate(submatrices_diag)]
+
+
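The identity this rewrite relies on is easy to check numerically. Below is a minimal sketch, not part of this commit, assuming `block_diag` can be imported from `pytensor.tensor.slinalg` and using arbitrary illustrative values:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import block_diag  # assumed import path

# Two small square blocks with illustrative values
a_val = np.array([[1.0, 2.0], [3.0, 4.0]])
b_val = np.array([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0], [11.0, 12.0, 13.0]])

a = pt.matrix("a")
b = pt.matrix("b")

# diag(block_diag(a, b)) should equal concat(diag(a), diag(b))
lhs = pt.diagonal(block_diag(a, b))
rhs = pt.concatenate([pt.diagonal(a), pt.diagonal(b)])

f = pytensor.function([a, b], [lhs, rhs])
out_lhs, out_rhs = f(a_val, b_val)
np.testing.assert_allclose(out_lhs, out_rhs)  # [1, 4, 5, 9, 13] on both sides

Once the rewrite above is registered, compiling the left-hand expression should produce the cheaper concatenated form directly, avoiding the materialization of the full block-diagonal matrix.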
+@register_canonicalize
+@register_stabilize
+@node_rewriter([det])
+def rewrite_det_blockdiag(fgraph, node):
+    """
+    This rewrite simplifies the determinant of a block-diagonal matrix to the product of the determinants of the individual sub-matrices.
+
+    det(block_diag(a, b, c, ...)) = prod(det(a), det(b), det(c), ...)
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check for an inner block_diag operation
+    potential_block_diag = node.inputs[0].owner
+    if not (
+        potential_block_diag
+        and isinstance(potential_block_diag.op, Blockwise)
+        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
+    ):
+        return None
+
+    # Find the composing sub-matrices and take the product of their determinants
+    sub_matrices = potential_block_diag.inputs
+    det_sub_matrices = [det(m) for m in sub_matrices]
+
+    return [prod(det_sub_matrices)]
+
+
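The determinant identity can be spot-checked the same way; again this is only an illustrative sketch under the same assumed `block_diag` import, not part of the commit:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import det
from pytensor.tensor.slinalg import block_diag  # assumed import path

a = pt.matrix("a")
b = pt.matrix("b")

# det(block_diag(a, b)) should equal det(a) * det(b)
lhs = det(block_diag(a, b))
rhs = det(a) * det(b)

f = pytensor.function([a, b], [lhs, rhs])
a_val = np.array([[2.0, 1.0], [1.0, 2.0]])  # det = 3
b_val = np.array([[3.0, 0.0], [0.0, 4.0]])  # det = 12
out_lhs, out_rhs = f(a_val, b_val)
np.testing.assert_allclose(out_lhs, out_rhs)  # both are 36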
+@register_canonicalize
+@register_stabilize
+@node_rewriter([slogdet])
+def rewrite_slogdet_blockdiag(fgraph, node):
+    """
+    This rewrite simplifies the slogdet of a block-diagonal matrix by combining the signs and log-determinants of the individual sub-matrices.
+
+    slogdet(block_diag(a, b, c, ...)) = prod(sign(a), sign(b), sign(c), ...), sum(logdet(a), logdet(b), logdet(c), ...)
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check for an inner block_diag operation
+    potential_block_diag = node.inputs[0].owner
+    if not (
+        potential_block_diag
+        and isinstance(potential_block_diag.op, Blockwise)
+        and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
+    ):
+        return None
+
+    # Find the composing sub-matrices, then combine their signs and log-determinants
+    sub_matrices = potential_block_diag.inputs
+    sign_sub_matrices, logdet_sub_matrices = zip(
+        *[slogdet(m) for m in sub_matrices]
+    )
+
+    return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
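For slogdet, which returns a (sign, log-abs-determinant) pair, a sketch of the identity under the same assumptions (illustrative values, assumed `block_diag` import, not part of the commit) might look like:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import slogdet
from pytensor.tensor.slinalg import block_diag  # assumed import path

a = pt.matrix("a")
b = pt.matrix("b")

sign_a, logdet_a = slogdet(a)
sign_b, logdet_b = slogdet(b)
sign_full, logdet_full = slogdet(block_diag(a, b))

# sign(block_diag) is the product of the signs;
# logdet(block_diag) is the sum of the log-determinants
f = pytensor.function(
    [a, b], [sign_full, sign_a * sign_b, logdet_full, logdet_a + logdet_b]
)

a_val = np.array([[0.0, 1.0], [1.0, 0.0]])  # det = -1 -> sign -1, logdet 0
b_val = np.array([[2.0, 0.0], [0.0, 5.0]])  # det = 10 -> sign +1, logdet log(10)
s_full, s_prod, l_full, l_sum = f(a_val, b_val)
np.testing.assert_allclose(s_full, s_prod)  # both -1
np.testing.assert_allclose(l_full, l_sum)   # both log(10)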