Skip to content

Added inv(diag) -> (1/diag) rewrite for eye mul #860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 120 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from pytensor import Variable
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import Eye, TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
Expand Down Expand Up @@ -377,3 +379,119 @@ def local_lift_through_linalg(
return [block_diag(*inner_matrices)]
else:
raise NotImplementedError # pragma: no cover


def _find_diag_from_eye_mul(potential_mul_input):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this duplicated from the other PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea. that one isnt merged yet so i had to add this function in this one too

# Check if the op is Elemwise and mul
if not (
potential_mul_input.owner is not None
and isinstance(potential_mul_input.owner.op, Elemwise)
and isinstance(potential_mul_input.owner.op.scalar_op, Mul)
):
return None

# Find whether any of the inputs to mul is Eye
inputs_to_mul = potential_mul_input.owner.inputs
eye_input = [
mul_input
for mul_input in inputs_to_mul
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
]
# Check if 1's are being put on the main diagonal only (k = 1)
k = getattr(eye_input[0].owner.inputs[-1], "data", 0).item()
if k != 0:
return None

# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
return None

# Get all non Eye inputs (scalars/matrices/vectors)
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
return eye_input, non_eye_inputs


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_for_diag_and_orthonormal(fgraph, node):
    """
    Replace the inverse of a diagonal or orthonormal matrix by a cheaper graph.

    This rewrite takes advantage of the fact that :
    1. for a diagonal matrix, the inverse is a diagonal matrix with the new
       diagonal entries as reciprocals of the original diagonal elements.
    2. for an orthonormal matrix, the inverse is simply the transpose.

    Cases currently handled :
    1. a diagonal matrix arising from the multiplication of eye with a single
       scalar/vector/matrix
    2. an orthonormal matrix arising from an SVD decomposition (U, Vh)

    Cases mentioned in the original description but not implemented here :
    a diagonal matrix from pt.diag of a vector, solve(x, eye), and an
    orthonormal matrix from pt.linalg.qr.

    TODO(review): a reviewer asked for this to be split into two rewrites
    (diagonal / orthonormal).
    TODO(review): solve(A, eye) also represents an inversion, as do other solve
    Ops (notably solve_triangular). Reviewers noted the logic is mostly the
    same as for inv/pinv: check that B is eye, then extract the A argument.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only direct inverse Ops are rewritten. The check uses the Op classes:
    # isinstance() needs a type, not a graph-building helper.
    core_op = node.op.core_op
    if not isinstance(core_op, (MatrixInverse, MatrixPinv)):
        return None

    input_to_inv = node.inputs[0]

    # Case 1: diagonal matrix written as eye * x
    eye_non_eye_inputs = _find_diag_from_eye_mul(input_to_inv)
    if eye_non_eye_inputs is not None:
        eye_input, non_eye_inputs = eye_non_eye_inputs

        # Dealing with only one other input
        if len(non_eye_inputs) != 1:
            return None

        eye, non_eye = eye_input[0], non_eye_inputs[0]

        # For a matrix only the diagonal entries matter. The extracted diagonal
        # has shape (..., n); a length-1 axis is inserted so that it
        # broadcasts against the (n, n) eye for batched inputs as well.
        if non_eye.type.broadcastable[-2:] == (False, False):
            non_eye = non_eye.diagonal(axis1=-1, axis2=-2)[..., None, :]

        # Scalars/vectors broadcast against eye directly.
        return [eye / non_eye]

    # Case 2: orthonormal matrix coming from SVD with compute_uv = True
    svd_node = input_to_inv.owner
    if svd_node is None:
        return None
    if not (
        isinstance(svd_node.op, Blockwise)
        and isinstance(svd_node.op.core_op, SVD)
        and svd_node.op.core_op.compute_uv is True
    ):
        return None

    # To make sure the input is orthonormal, it must not be S
    # (per review: output order is U, S, Vh, so S is index 1).
    if input_to_inv is svd_node.outputs[1]:
        return None

    # NOTE(review): `.T` is kept from the original change; for batched inputs
    # a transpose of only the last two axes may be required - confirm.
    return [input_to_inv.T]
37 changes: 37 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,40 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]

np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)


@pytest.mark.parametrize(
    "shape",
    [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
    ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
)
def test_inv_diag_from_eye_mul(shape):
    """Check that inv(eye * x) is rewritten and still matches numpy's inverse."""
    # TODO(review): a reviewer asked for a matching test for the orthonormal case.

    # Build inv(eye(7) * x) for a scalar/vector/matrix/batched x
    x = pt.tensor("x", shape=shape)
    x_inv = pt.linalg.inv(pt.eye(7) * x)

    # The compiled graph must not contain any explicit inverse Op
    f_rewritten = function([x], x_inv, mode="FAST_RUN")
    remaining_ops = [node.op for node in f_rewritten.maker.fgraph.apply_nodes]
    assert not any(isinstance(op, MatrixInverse) for op in remaining_ops)
    assert not any(isinstance(op, MatrixPinv) for op in remaining_ops)

    # Compare against numpy; np.array() also wraps the bare float that
    # np.random.rand() returns for the scalar case.
    x_test = np.array(np.random.rand(*shape)).astype(config.floatX)
    inverse_matrix = np.linalg.inv(np.eye(7) * x_test)
    rewritten_inverse = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(inverse_matrix, rewritten_inverse, atol=tol, rtol=tol)