|
31 | 31 | inv,
|
32 | 32 | kron,
|
33 | 33 | pinv,
|
| 34 | + slogdet, |
34 | 35 | svd,
|
35 | 36 | )
|
36 | 37 | from pytensor.tensor.rewriting.basic import (
|
@@ -620,39 +621,110 @@ def rewrite_inv_inv(fgraph, node):
|
620 | 621 | @register_stabilize
|
621 | 622 | @node_rewriter([ExtractDiag])
|
622 | 623 | def rewrite_diag_blockdiag(fgraph, node):
|
| 624 | + """ |
| 625 | + This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices. |
| 626 | +
|
| 627 | + diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...) |
| 628 | +
|
| 629 | + Parameters |
| 630 | + ---------- |
| 631 | + fgraph: FunctionGraph |
| 632 | + Function graph being optimized |
| 633 | + node: Apply |
| 634 | + Node of the function graph to be optimized |
| 635 | +
|
| 636 | + Returns |
| 637 | + ------- |
| 638 | + list of Variable, optional |
| 639 | + List of optimized variables, or None if no optimization was performed |
| 640 | + """ |
623 | 641 | # Check for inner block_diag operation
|
624 |
| - potential_blockdiag = node.inputs[0].owner |
| 642 | + potential_block_diag = node.inputs[0].owner |
625 | 643 | if not (
|
626 |
| - potential_blockdiag |
627 |
| - and isinstance(potential_blockdiag.op, Blockwise) |
628 |
| - and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) |
| 644 | + potential_block_diag |
| 645 | + and isinstance(potential_block_diag.op, Blockwise) |
| 646 | + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
629 | 647 | ):
|
630 | 648 | return None
|
631 | 649 |
|
632 | 650 | # Find the composing sub_matrices
|
633 |
| - submatrices = potential_blockdiag.inputs |
| 651 | + submatrices = potential_block_diag.inputs |
634 | 652 | submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))]
|
635 |
| - output = [concatenate(submatrices_diag)] |
636 | 653 |
|
637 |
| - return output |
| 654 | + return [concatenate(submatrices_diag)] |
638 | 655 |
|
639 | 656 |
|
640 | 657 | @register_canonicalize
|
641 | 658 | @register_stabilize
|
642 | 659 | @node_rewriter([det])
|
643 | 660 | def rewrite_det_blockdiag(fgraph, node):
|
| 661 | + """ |
| 662 | + This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values. |
| 663 | +
|
| 664 | + det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...) |
| 665 | +
|
| 666 | + Parameters |
| 667 | + ---------- |
| 668 | + fgraph: FunctionGraph |
| 669 | + Function graph being optimized |
| 670 | + node: Apply |
| 671 | + Node of the function graph to be optimized |
| 672 | +
|
| 673 | + Returns |
| 674 | + ------- |
| 675 | + list of Variable, optional |
| 676 | + List of optimized variables, or None if no optimization was performed |
| 677 | + """ |
644 | 678 | # Check for inner block_diag operation
|
645 |
| - potential_blockdiag = node.inputs[0].owner |
| 679 | + potential_block_diag = node.inputs[0].owner |
646 | 680 | if not (
|
647 |
| - potential_blockdiag |
648 |
| - and isinstance(potential_blockdiag.op, Blockwise) |
649 |
| - and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) |
| 681 | + potential_block_diag |
| 682 | + and isinstance(potential_block_diag.op, Blockwise) |
| 683 | + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
650 | 684 | ):
|
651 | 685 | return None
|
652 | 686 |
|
653 | 687 | # Find the composing sub_matrices
|
654 |
| - sub_matrices = potential_blockdiag.inputs |
| 688 | + sub_matrices = potential_block_diag.inputs |
655 | 689 | det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))]
|
656 |
| - prod_det_sub_matrices = prod(det_sub_matrices) |
657 | 690 |
|
658 |
| - return [prod_det_sub_matrices] |
| 691 | + return [prod(det_sub_matrices)] |
| 692 | + |
| 693 | + |
| 694 | +@register_canonicalize |
| 695 | +@register_stabilize |
| 696 | +@node_rewriter([slogdet]) |
| 697 | +def rewrite_slogdet_blockdiag(fgraph, node): |
| 698 | + """ |
| 699 | + This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those |
| 700 | +
|
| 701 | + slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) |
| 702 | +
|
| 703 | + Parameters |
| 704 | + ---------- |
| 705 | + fgraph: FunctionGraph |
| 706 | + Function graph being optimized |
| 707 | + node: Apply |
| 708 | + Node of the function graph to be optimized |
| 709 | +
|
| 710 | + Returns |
| 711 | + ------- |
| 712 | + list of Variable, optional |
| 713 | + List of optimized variables, or None if no optimization was performed |
| 714 | + """ |
| 715 | + # Check for inner block_diag operation |
| 716 | + potential_block_diag = node.inputs[0].owner |
| 717 | + if not ( |
| 718 | + potential_block_diag |
| 719 | + and isinstance(potential_block_diag.op, Blockwise) |
| 720 | + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) |
| 721 | + ): |
| 722 | + return None |
| 723 | + |
| 724 | + # Find the composing sub_matrices |
| 725 | + sub_matrices = potential_block_diag.inputs |
| 726 | + sign_sub_matrices, logdet_sub_matrices = zip( |
| 727 | + *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] |
| 728 | + ) |
| 729 | + |
| 730 | + return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] |
0 commit comments