Skip to content

Commit 96c9cbf

Browse files
committed
minor changes
1 parent f2fff31 commit 96c9cbf

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pytensor/tensor/rewriting/linalg.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@ def rewrite_inv_for_diag_eye_mul(fgraph, node):
561561
list of Variable, optional
562562
List of optimized variables, or None if no optimization was performed
563563
"""
564-
# List of useful operations : Inv, Pinv
565564
valid_inverses = (MatrixInverse, MatrixPinv)
566565
core_op = node.op.core_op
567566
if not (isinstance(core_op, valid_inverses)):
@@ -579,15 +578,15 @@ def rewrite_inv_for_diag_eye_mul(fgraph, node):
579578
if len(non_eye_inputs) != 1:
580579
return None
581580

582-
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
581+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
583582

584-
# For a matrix, we can first get the diagonal and then only use those
585-
if useful_non_eye.type.broadcastable[-2:] == (False, False):
583+
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
584+
if non_eye_input.type.broadcastable[-2:] == (False, False):
586585
# For Matrix
587-
return [useful_eye * 1 / useful_non_eye.diagonal(axis1=-1, axis2=-2)]
586+
return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)]
588587
else:
589588
# For Scalar/Vector
590-
return [useful_eye * 1 / useful_non_eye]
589+
return [eye_input / non_eye_input]
591590

592591

593592
def rewrite_inv_for_diag_ptdiag(fgraph, node):
@@ -626,7 +625,7 @@ def rewrite_inv_for_orthonormal(fgraph, node):
626625
):
627626
return None
628627

629-
# To make sure input is orthonormal, we have to check that its not S (output order is U, S, Vh, so S is index 1)
628+
# To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
630629
if input_to_inv == input_to_inv.owner.outputs[1]:
631630
return None
632631

tests/tensor/rewriting/test_linalg.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
from tests.test_rop import break_op
4141

4242

43+
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
44+
45+
4346
def test_rop_lop():
4447
mx = matrix("mx")
4548
mv = matrix("mv")
@@ -580,8 +583,8 @@ def test_inv_diag_from_eye_mul(shape):
580583
assert_allclose(
581584
inverse_matrix,
582585
rewritten_inverse,
583-
atol=1e-3 if config.floatX == "float32" else 1e-8,
584-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
586+
atol=ATOL,
587+
rtol=RTOL,
585588
)
586589

587590

@@ -590,7 +593,7 @@ def test_inv_orthonormal():
590593
u, s, vh = pt.linalg.svd(x)
591594
# Calculating inverse using pt.linalg.inv
592595
u_inv = pt.linalg.inv(u)
593-
print(u_inv.dprint())
596+
594597
# REWRITE TEST
595598
f_rewritten = function([x], u_inv, mode="FAST_RUN")
596599
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -607,6 +610,6 @@ def test_inv_orthonormal():
607610
assert_allclose(
608611
inverse_matrix,
609612
rewritten_inverse,
610-
atol=1e-3 if config.floatX == "float32" else 1e-8,
611-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
613+
atol=ATOL,
614+
rtol=RTOL,
612615
)

0 commit comments

Comments
 (0)