Skip to content

Commit 1db9999

Browse files
committed
updated rewrites
1 parent 625e98c commit 1db9999

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

pytensor/tensor/rewriting/linalg.py

+94
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import cast
44

55
from pytensor import Variable
6+
from pytensor import tensor as pt
67
from pytensor.graph import Apply, FunctionGraph
78
from pytensor.graph.rewriting.basic import (
89
copy_stack_trace,
@@ -611,3 +612,96 @@ def rewrite_inv_inv(fgraph, node):
611612
):
612613
return None
613614
return [potential_inner_inv.inputs[0]]
615+
616+
617+
@register_canonicalize
618+
@register_stabilize
619+
@node_rewriter([Blockwise])
620+
def rewrite_inv_eye_to_eye(fgraph, node):
621+
"""
622+
This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself
623+
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op.
624+
Parameters
625+
----------
626+
fgraph: FunctionGraph
627+
Function graph being optimized
628+
node: Apply
629+
Node of the function graph to be optimized
630+
Returns
631+
-------
632+
list of Variable, optional
633+
List of optimized variables, or None if no optimization was performed
634+
"""
635+
valid_inverses = (MatrixInverse, MatrixPinv)
636+
core_op = node.op.core_op
637+
if not (isinstance(core_op, valid_inverses)):
638+
return None
639+
640+
# Check whether input to inverse is Eye and the 1's are on main diagonal
641+
eye_check = node.inputs[0]
642+
if not (
643+
eye_check.owner
644+
and isinstance(eye_check.owner.op, Eye)
645+
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
646+
):
647+
return None
648+
return [eye_check]
649+
650+
651+
@register_canonicalize
652+
@register_stabilize
653+
@node_rewriter([Blockwise])
654+
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
655+
"""
656+
This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
657+
This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix
658+
659+
Parameters
660+
----------
661+
fgraph: FunctionGraph
662+
Function graph being optimized
663+
node: Apply
664+
Node of the function graph to be optimized
665+
666+
Returns
667+
-------
668+
list of Variable, optional
669+
List of optimized variables, or None if no optimization was performed
670+
"""
671+
valid_inverses = (MatrixInverse, MatrixPinv)
672+
core_op = node.op.core_op
673+
if not (isinstance(core_op, valid_inverses)):
674+
return None
675+
676+
inputs = node.inputs[0]
677+
# Check for use of pt.diag first
678+
if (
679+
inputs.owner
680+
and isinstance(inputs.owner.op, AllocDiag)
681+
and AllocDiag.is_offset_zero(inputs.owner)
682+
):
683+
inv_input = inputs.owner.inputs[0]
684+
if inv_input.type.ndim == 1:
685+
inv_val = pt.diag(1 / inv_input)
686+
return [inv_val]
687+
688+
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
689+
inputs_or_none = _find_diag_from_eye_mul(inputs)
690+
if inputs_or_none is None:
691+
return None
692+
693+
eye_input, non_eye_inputs = inputs_or_none
694+
695+
# Dealing with only one other input
696+
if len(non_eye_inputs) != 1:
697+
return None
698+
699+
non_eye_input = non_eye_inputs[0]
700+
701+
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
702+
if non_eye_input.type.broadcastable[-2:] == (False, False):
703+
# For Matrix
704+
return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)]
705+
else:
706+
# For Vector or Scalar
707+
return [eye_input / non_eye_input]

0 commit comments

Comments
 (0)