Skip to content

Commit 06204ec

Browse files
committed
minor changes
1 parent 182cb96 commit 06204ec

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

pytensor/tensor/rewriting/linalg.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,6 @@ def rewrite_inv_for_diag_eye_mul(fgraph, node):
591591
list of Variable, optional
592592
List of optimized variables, or None if no optimization was performed
593593
"""
594-
# List of useful operations : Inv, Pinv
595594
valid_inverses = (MatrixInverse, MatrixPinv)
596595
core_op = node.op.core_op
597596
if not (isinstance(core_op, valid_inverses)):
@@ -609,15 +608,15 @@ def rewrite_inv_for_diag_eye_mul(fgraph, node):
609608
if len(non_eye_inputs) != 1:
610609
return None
611610

612-
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]
611+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
613612

614-
# For a matrix, we can first get the diagonal and then only use those
615-
if useful_non_eye.type.broadcastable[-2:] == (False, False):
613+
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
614+
if non_eye_input.type.broadcastable[-2:] == (False, False):
616615
# For Matrix
617-
return [useful_eye * 1 / useful_non_eye.diagonal(axis1=-1, axis2=-2)]
616+
return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)]
618617
else:
619618
# For Scalar/Vector
620-
return [useful_eye * 1 / useful_non_eye]
619+
return [eye_input / non_eye_input]
621620

622621

623622
def rewrite_inv_for_diag_ptdiag(fgraph, node):
@@ -656,7 +655,7 @@ def rewrite_inv_for_orthonormal(fgraph, node):
656655
):
657656
return None
658657

659-
# To make sure input is orthonormal, we have to check that its not S (output order is U, S, Vh, so S is index 1)
658+
# To make sure input is orthonormal, we have to check that its not S (output order of SVD is U, S, Vh, so S is index 1) (S matrix consists of singular values and it is not orthonormal)
660659
if input_to_inv == input_to_inv.owner.outputs[1]:
661660
return None
662661

tests/tensor/rewriting/test_linalg.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
from tests.test_rop import break_op
4141

4242

43+
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
44+
45+
4346
def test_rop_lop():
4447
mx = matrix("mx")
4548
mv = matrix("mv")
@@ -589,8 +592,8 @@ def test_inv_diag_from_eye_mul(shape):
589592
assert_allclose(
590593
inverse_matrix,
591594
rewritten_inverse,
592-
atol=1e-3 if config.floatX == "float32" else 1e-8,
593-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
595+
atol=ATOL,
596+
rtol=RTOL,
594597
)
595598

596599

@@ -599,7 +602,7 @@ def test_inv_orthonormal():
599602
u, s, vh = pt.linalg.svd(x)
600603
# Calculating inverse using pt.linalg.inv
601604
u_inv = pt.linalg.inv(u)
602-
print(u_inv.dprint())
605+
603606
# REWRITE TEST
604607
f_rewritten = function([x], u_inv, mode="FAST_RUN")
605608
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -616,6 +619,6 @@ def test_inv_orthonormal():
616619
assert_allclose(
617620
inverse_matrix,
618621
rewritten_inverse,
619-
atol=1e-3 if config.floatX == "float32" else 1e-8,
620-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
622+
atol=ATOL,
623+
rtol=RTOL,
621624
)

0 commit comments

Comments
 (0)