2
2
from collections .abc import Callable
3
3
from typing import cast
4
4
5
+ import pytensor .tensor as pt
5
6
from pytensor import Variable
6
7
from pytensor .graph import Apply , FunctionGraph
7
8
from pytensor .graph .rewriting .basic import (
@@ -605,7 +606,7 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
605
606
):
606
607
inv_input = inputs .owner .inputs [0 ]
607
608
if inv_input .type .ndim == 1 :
608
- inv_val = diagonal (1 / inv_input )
609
+ inv_val = pt . diag (1 / inv_input )
609
610
return [inv_val ]
610
611
611
612
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
@@ -619,21 +620,14 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
619
620
if len (non_eye_inputs ) != 1 :
620
621
return None
621
622
622
- eye_input , non_eye_input = eye_input [ 0 ], non_eye_inputs [0 ]
623
+ non_eye_input = non_eye_inputs [0 ]
623
624
624
625
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
625
626
if non_eye_input .type .broadcastable [- 2 :] == (False , False ):
626
627
# For Matrix
627
628
return [eye_input / non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 )]
628
- elif non_eye_input .type .broadcastable [- 2 :] == (True , True ):
629
- # For Scalar
630
- inv = eye_input / non_eye_input
631
- print (eye_input .type )
632
- print (non_eye_input .type )
633
- print (inv .type )
634
- return [eye_input / non_eye_input ]
635
629
else :
636
- # For Vector
630
+ # For Vector or Scalar
637
631
return [eye_input / non_eye_input ]
638
632
639
633
0 commit comments