Skip to content

Commit 5ee276a

Browse files
committed
added docstrings for rewrite and helper
1 parent ea42689 commit 5ee276a

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -572,17 +572,22 @@ def svd_uv_merge(fgraph, node):
572572
return [cl.outputs[1]]
573573

574574

575-
def _find_solve_with_eye(node):
575+
def _find_solve_with_eye(node) -> bool:
576+
"""
577+
The result of solve(A, b) is the solution x to the linear equation Ax = b. If b is an identity matrix (Eye), x is simply inv(A).
578+
Here, we are just recognising whether the solve operation returns an inverse or not; not replacing it because solve is mathematically more stable than inv.
579+
"""
576580
valid_solves = (Solve, SolveTriangular)
577-
# First, we look for the solve op
581+
# First, we verify whether we have a valid solve op
578582
if not (
579583
isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_solves)
580584
):
581585
return False
582-
# Check whether second input to solve is Eye
586+
# If the current op is solve, we check for b. If b is an identity matrix (Eye), we can return True
583587
solve_inputs = node.inputs
584588
eye_input = solve_inputs[1].owner
585589

590+
# We check for b = Eye and also make sure that if it was an Eye, then k = 0 (1's are only across the main diagonal)
586591
if not (eye_input and isinstance(eye_input.op, Eye)):
587592
return False
588593

@@ -595,9 +600,29 @@ def _find_solve_with_eye(node):
595600
@register_stabilize
596601
@node_rewriter([Blockwise])
597602
def rewrite_inv_inv(fgraph, node):
603+
"""
604+
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
605+
606+
Here, we check for direct inverse operations (inv/pinv) and also solve operations (solve/solve_triangular) in the case when b = Eye. This allows any combination of these "inverse" nodes to be simply rewritten.
607+
608+
Parameters
609+
----------
610+
fgraph: FunctionGraph
611+
Function graph being optimized
612+
node: Apply
613+
Node of the function graph to be optimized
614+
615+
Returns
616+
-------
617+
list of Variable, optional
618+
List of optimized variables, or None if no optimization was performed
619+
"""
598620
valid_inverses = (MatrixInverse, MatrixPinv, Solve, SolveTriangular)
599621
valid_solves = (Solve, SolveTriangular)
600622
# Check if its a valid inverse operation (either inv/pinv or if its solve, then b = eye)
623+
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
624+
# If the outer operation is a solve op with b = Eye, it treats it as inverse and finds the inner operation
625+
# If the outer operation is not an inverse (neither inv nor solve with eye), we do not apply this rewrite
601626
inv_check = False
602627
if isinstance(node.op, Blockwise) and isinstance(node.op.core_op, valid_inverses):
603628
inv_check = True
@@ -611,16 +636,17 @@ def rewrite_inv_inv(fgraph, node):
611636
if potential_inner_inv is None or potential_inner_inv.op is None:
612637
return None
613638

614-
# Check if its a valid inverse operation (either inv/pinv or if its solve, then b = eye)
615-
inv_check = False
639+
# Similar to the check for outer operation, we now run the same checks for the inner op.
640+
# If its an inverse or solve with eye, we apply the rewrite. Otherwise, we return None.
641+
inv_check_inner = False
616642
if isinstance(potential_inner_inv.op, Blockwise) and isinstance(
617643
potential_inner_inv.op.core_op, valid_inverses
618644
):
619-
inv_check = True
645+
inv_check_inner = True
620646
if isinstance(potential_inner_inv.op.core_op, valid_solves):
621-
inv_check = _find_solve_with_eye(potential_inner_inv)
647+
inv_check_inner = _find_solve_with_eye(potential_inner_inv)
622648

623-
if not inv_check:
649+
if not inv_check_inner:
624650
return None
625651

626652
if not (

tests/tensor/rewriting/test_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,7 @@ def test_transinv_to_invtrans():
9494
X = matrix("X")
9595
Y = matrix_inverse(X)
9696
Z = Y.transpose()
97-
print(Z.dprint())
9897
f = pytensor.function([X], Z)
99-
print(f.dprint())
10098
if config.mode != "FAST_COMPILE":
10199
for node in f.maker.fgraph.toposort():
102100
if isinstance(node.op, MatrixInverse):

0 commit comments

Comments
 (0)