Skip to content

Commit 5fad484

Browse files
committed
rewrite deals with pt.diag as well now
1 parent 2f98fa2 commit 5fad484

File tree

2 files changed

+5
-14
lines changed

2 files changed

+5
-14
lines changed

pytensor/tensor/rewriting/linalg.py

+4-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from collections.abc import Callable
33
from typing import cast
44

5+
import pytensor.tensor as pt
56
from pytensor import Variable
67
from pytensor.graph import Apply, FunctionGraph
78
from pytensor.graph.rewriting.basic import (
@@ -605,7 +606,7 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
605606
):
606607
inv_input = inputs.owner.inputs[0]
607608
if inv_input.type.ndim == 1:
608-
inv_val = diagonal(1 / inv_input)
609+
inv_val = pt.diag(1 / inv_input)
609610
return [inv_val]
610611

611612
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
@@ -619,21 +620,14 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
619620
if len(non_eye_inputs) != 1:
620621
return None
621622

622-
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
623+
non_eye_input = non_eye_inputs[0]
623624

624625
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
625626
if non_eye_input.type.broadcastable[-2:] == (False, False):
626627
# For Matrix
627628
return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)]
628-
elif non_eye_input.type.broadcastable[-2:] == (True, True):
629-
# For Scalar
630-
inv = eye_input / non_eye_input
631-
print(eye_input.type)
632-
print(non_eye_input.type)
633-
print(inv.type)
634-
return [eye_input / non_eye_input]
635629
else:
636-
# For Vector
630+
# For Vector or Scalar
637631
return [eye_input / non_eye_input]
638632

639633

tests/tensor/rewriting/test_linalg.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@
4343
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
4444

4545

46-
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
47-
48-
4946
def test_rop_lop():
5047
mx = matrix("mx")
5148
mv = matrix("mv")
@@ -616,7 +613,7 @@ def test_inv_diag_from_diag():
616613
x_test = np.random.rand(10)
617614
x_test_matrix = np.eye(10) * x_test
618615
inverse_matrix = np.linalg.inv(x_test_matrix)
619-
rewritten_inverse = f_rewritten(x_test_matrix)
616+
rewritten_inverse = f_rewritten(x_test)
620617

621618
assert_allclose(
622619
inverse_matrix,

0 commit comments

Comments
 (0)