2
2
from collections .abc import Callable
3
3
from typing import cast
4
4
5
- import pytensor .tensor as pt
6
5
from pytensor import Variable
7
6
from pytensor import tensor as pt
8
7
from pytensor .graph import Apply , FunctionGraph
@@ -824,7 +823,7 @@ def rewrite_slogdet_blockdiag(fgraph, node):
824
823
@register_canonicalize
825
824
@register_stabilize
826
825
@node_rewriter ([Blockwise ])
827
- def rewrite_cholesky_eye_to_eye (fgraph , node ):
826
+ def rewrite_remove_useless_cholesky (fgraph , node ):
828
827
"""
829
828
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
830
829
@@ -847,14 +846,15 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
847
846
return None
848
847
849
848
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
850
- eye_check = node .inputs [0 ]
849
+ potential_eye = node .inputs [0 ]
851
850
if not (
852
- eye_check .owner
853
- and isinstance (eye_check .owner .op , Eye )
854
- and getattr (eye_check .owner .inputs [- 1 ], "data" , - 1 ).item () == 0
851
+ potential_eye .owner
852
+ and isinstance (potential_eye .owner .op , Eye )
853
+ and hasattr (potential_eye .owner .inputs [- 1 ], "data" )
854
+ and potential_eye .owner .inputs [- 1 ].data .item () == 0
855
855
):
856
856
return None
857
- return [eye_check ]
857
+ return [potential_eye ]
858
858
859
859
860
860
@register_canonicalize
@@ -872,10 +872,9 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
872
872
and isinstance (inputs .owner .op , AllocDiag )
873
873
and AllocDiag .is_offset_zero (inputs .owner )
874
874
):
875
- cholesky_input = inputs .owner .inputs [0 ]
876
- if cholesky_input .type .ndim == 1 :
877
- cholesky_val = pt .diag (cholesky_input ** 0.5 )
878
- return [cholesky_val ]
875
+ diag_input = inputs .owner .inputs [0 ]
876
+ cholesky_val = pt .diag (diag_input ** 0.5 )
877
+ return [cholesky_val ]
879
878
880
879
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
881
880
inputs_or_none = _find_diag_from_eye_mul (inputs )
@@ -893,8 +892,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
893
892
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
894
893
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
895
894
if non_eye_input .type .broadcastable [- 2 :] == (False , False ):
896
- # For Matrix
897
- return [eye_input * (non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 ) ** 0.5 )]
898
- else :
899
- # For Vector or Scalar
900
- return [eye_input * (non_eye_input ** 0.5 )]
895
+ non_eye_input = non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 )
896
+
897
+ return [eye_input * (non_eye_input ** 0.5 )]
0 commit comments