Skip to content

Add rewrites for inv(diag(x)) and inv(eye) #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 93 additions & 3 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import cast

from pytensor import Variable
from pytensor import tensor as pt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import what you need directly from the various packages within pytensor.tensor, it makes the code easier to follow for people in the future

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah that makes sense. i'll do that from next time cuz the test file has way too many uses of pt rn 😓

from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
Expand Down Expand Up @@ -48,6 +49,7 @@


logger = logging.getLogger(__name__)
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)


def is_matrix_transpose(x: TensorVariable) -> bool:
Expand Down Expand Up @@ -592,11 +594,10 @@
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
valid_inverses = (MatrixInverse, MatrixPinv)
# Check if its a valid inverse operation (either inv/pinv)
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
# If the outer operation is not a valid inverse, we do not apply this rewrite
if not isinstance(node.op.core_op, valid_inverses):
if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
return None

potential_inner_inv = node.inputs[0].owner
Expand All @@ -607,7 +608,96 @@
if not (
potential_inner_inv
and isinstance(potential_inner_inv.op, Blockwise)
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
):
return None
return [potential_inner_inv.inputs[0]]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_eye_to_eye(fgraph, node):
"""
This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None

# Check whether input to inverse is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
if not (
potential_eye.owner
and isinstance(potential_eye.owner.op, Eye)
and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0
):
return None
return [potential_eye]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None

inputs = node.inputs[0]
# Check for use of pt.diag first
if (
inputs.owner
and isinstance(inputs.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(inputs.owner)
):
inv_input = inputs.owner.inputs[0]
inv_val = pt.diag(1 / inv_input)
return [inv_val]

# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(inputs)
if inputs_or_none is None:
return None

eye_input, non_eye_inputs = inputs_or_none

# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None

Check warning on line 694 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L694

Added line #L694 was not covered by tests

non_eye_input = non_eye_inputs[0]

# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
if non_eye_input.type.broadcastable[-2:] == (False, False):
non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2)
non_eye_input = pt.shape_padaxis(non_eye_diag, -2)

return [eye_input / non_eye_input]
100 changes: 97 additions & 3 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
from tests.test_rop import break_op


ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8


def test_rop_lop():
mx = matrix("mx")
mv = matrix("mv")
Expand Down Expand Up @@ -557,14 +560,105 @@ def test_svd_uv_merge():
assert svd_counter == 1


def get_pt_function(x, op_name):
return getattr(pt.linalg, op_name)(x)


@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
def get_pt_function(x, op_name):
return getattr(pt.linalg, op_name)(x)

x = pt.matrix("x")
op1 = get_pt_function(x, inv_op_1)
op2 = get_pt_function(op1, inv_op_2)
rewritten_out = rewrite_graph(op2)
assert rewritten_out == x


@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_eye_to_eye(inv_op):
x = pt.eye(10)
x_inv = get_pt_function(x, inv_op)
f_rewritten = function([], x_inv, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

# Rewrite Test
valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)

# Value Test
x_test = np.eye(10)
x_inv_val = np.linalg.inv(x_test)
rewritten_val = f_rewritten()

assert_allclose(
x_inv_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


@pytest.mark.parametrize(
"shape",
[(), (7,), (7, 7), (5, 7, 7)],
ids=["scalar", "vector", "matrix", "batched"],
)
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_eye_mul(shape, inv_op):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
x_diag = pt.eye(7) * x
# Calculating inverse using pt.linalg.inv
x_inv = get_pt_function(x_diag, inv_op)

# REWRITE TEST
f_rewritten = function([x], x_inv, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)

# NUMERIC VALUE TEST
if len(shape) == 0:
x_test = np.array(np.random.rand()).astype(config.floatX)
elif len(shape) == 1:
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
inverse_matrix = np.linalg.inv(x_test_matrix)
rewritten_inverse = f_rewritten(x_test)

assert_allclose(
inverse_matrix,
rewritten_inverse,
atol=ATOL,
rtol=RTOL,
)


@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_diag(inv_op):
x = pt.dvector("x")
x_diag = pt.diag(x)
x_inv = get_pt_function(x_diag, inv_op)

# REWRITE TEST
f_rewritten = function([x], x_inv, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)

# NUMERIC VALUE TEST
x_test = np.random.rand(10)
x_test_matrix = np.eye(10) * x_test
inverse_matrix = np.linalg.inv(x_test_matrix)
rewritten_inverse = f_rewritten(x_test)

assert_allclose(
inverse_matrix,
rewritten_inverse,
atol=ATOL,
rtol=RTOL,
)
Loading