2
2
from collections .abc import Callable
3
3
from typing import cast
4
4
5
+ import pytensor .tensor as pt
5
6
from pytensor import Variable
6
7
from pytensor .graph import Apply , FunctionGraph
7
8
from pytensor .graph .rewriting .basic import (
@@ -652,13 +653,24 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
652
653
@register_canonicalize
653
654
@register_stabilize
654
655
@node_rewriter ([Blockwise ])
655
- def rewrite_cholesky_diag_from_eye_mul (fgraph , node ):
656
+ def rewrite_cholesky_diag_to_sqrt_diag (fgraph , node ):
656
657
# Find whether cholesky op is being applied
657
658
if not isinstance (node .op .core_op , Cholesky ):
658
659
return None
659
660
660
- # Check whether input is diagonal from multiplication of identity matrix with a tensor
661
661
inputs = node .inputs [0 ]
662
+ # Check for use of pt.diag first
663
+ if (
664
+ inputs .owner
665
+ and isinstance (inputs .owner .op , AllocDiag )
666
+ and AllocDiag .is_offset_zero (inputs .owner )
667
+ ):
668
+ cholesky_input = inputs .owner .inputs [0 ]
669
+ if cholesky_input .type .ndim == 1 :
670
+ cholesky_val = pt .diag (cholesky_input ** 0.5 )
671
+ return [cholesky_val ]
672
+
673
+ # Otherwise, check whether the input is an elementwise multiplication with an identity matrix -- this also results in a diagonal matrix
662
674
inputs_or_none = _find_diag_from_eye_mul (inputs )
663
675
if inputs_or_none is None :
664
676
return None
@@ -669,6 +681,13 @@ def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
669
681
if len (non_eye_inputs ) != 1 :
670
682
return None
671
683
672
- eye_input , non_eye_input = eye_input [ 0 ], non_eye_inputs [0 ]
684
+ non_eye_input = non_eye_inputs [0 ]
673
685
674
- return [eye_input * (non_eye_input ** 0.5 )]
686
+ # The Cholesky factor of a diagonal matrix is the diagonal matrix holding the square roots of its diagonal elements.
687
+ # If the non-eye input is itself a matrix, only its diagonal survives the multiplication with the identity, so extract the diagonal first and use only those values.
688
+ if non_eye_input .type .broadcastable [- 2 :] == (False , False ):
689
+ # For Matrix
690
+ return [eye_input * (non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 ) ** 0.5 )]
691
+ else :
692
+ # For Vector or Scalar
693
+ return [eye_input * (non_eye_input ** 0.5 )]
0 commit comments