-
Notifications
You must be signed in to change notification settings - Fork 133
Added inv(diag) -> (1/diag) rewrite for eye mul #860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
5062e4f
9bf0077
fb8e429
00a8613
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,12 +5,14 @@ | |||||||||
from pytensor import Variable | ||||||||||
from pytensor.graph import Apply, FunctionGraph | ||||||||||
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter | ||||||||||
from pytensor.tensor.basic import TensorVariable, diagonal | ||||||||||
from pytensor.scalar.basic import Mul | ||||||||||
from pytensor.tensor.basic import Eye, TensorVariable, diagonal | ||||||||||
from pytensor.tensor.blas import Dot22 | ||||||||||
from pytensor.tensor.blockwise import Blockwise | ||||||||||
from pytensor.tensor.elemwise import DimShuffle | ||||||||||
from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||||||||||
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod | ||||||||||
from pytensor.tensor.nlinalg import ( | ||||||||||
SVD, | ||||||||||
KroneckerProduct, | ||||||||||
MatrixInverse, | ||||||||||
MatrixPinv, | ||||||||||
|
@@ -377,3 +379,119 @@ def local_lift_through_linalg( | |||||||||
return [block_diag(*inner_matrices)] | ||||||||||
else: | ||||||||||
raise NotImplementedError # pragma: no cover | ||||||||||
|
||||||||||
|
||||||||||
def _find_diag_from_eye_mul(potential_mul_input): | ||||||||||
# Check if the op is Elemwise and mul | ||||||||||
if not ( | ||||||||||
potential_mul_input.owner is not None | ||||||||||
and isinstance(potential_mul_input.owner.op, Elemwise) | ||||||||||
and isinstance(potential_mul_input.owner.op.scalar_op, Mul) | ||||||||||
): | ||||||||||
return None | ||||||||||
|
||||||||||
# Find whether any of the inputs to mul is Eye | ||||||||||
inputs_to_mul = potential_mul_input.owner.inputs | ||||||||||
eye_input = [ | ||||||||||
mul_input | ||||||||||
for mul_input in inputs_to_mul | ||||||||||
if mul_input.owner and isinstance(mul_input.owner.op, Eye) | ||||||||||
] | ||||||||||
# Check if 1's are being put on the main diagonal only (k = 1) | ||||||||||
k = getattr(eye_input[0].owner.inputs[-1], "data", 0).item() | ||||||||||
if k != 0: | ||||||||||
return None | ||||||||||
|
||||||||||
# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite | ||||||||||
if eye_input and eye_input[0].broadcastable[-2:] != (False, False): | ||||||||||
return None | ||||||||||
|
||||||||||
# Get all non Eye inputs (scalars/matrices/vectors) | ||||||||||
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) | ||||||||||
return eye_input, non_eye_inputs | ||||||||||
|
||||||||||
|
||||||||||
@register_canonicalize | ||||||||||
@register_stabilize | ||||||||||
@node_rewriter([Blockwise]) | ||||||||||
def rewrite_inv_for_diag_and_orthonormal(fgraph, node): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. imo this should be split into two rewrites There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay i'll do that |
||||||||||
""" | ||||||||||
This rewrite covers a few cases which take advantage of the fact that : | ||||||||||
1. for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. | ||||||||||
2. for an orthonormal matrix, the inverse is simply the transpose. | ||||||||||
|
||||||||||
For simplicity, this function deals with the following cases : | ||||||||||
1. for a diagonal matrix : | ||||||||||
i) arising from the multiplicaton of eye with a scalar/vector/matrix | ||||||||||
ii) arising from pt.diag of a vector | ||||||||||
iii) solve(x, eye) directly returns inv(x) | ||||||||||
2. for an orthonormal matrix : | ||||||||||
i) arising from pt.linalg.svd decomposition (U, Vh) | ||||||||||
ii) arising from pt.linalg.qr | ||||||||||
|
||||||||||
Parameters | ||||||||||
---------- | ||||||||||
fgraph: FunctionGraph | ||||||||||
Function graph being optimized | ||||||||||
node: Apply | ||||||||||
Node of the function graph to be optimized | ||||||||||
|
||||||||||
Returns | ||||||||||
------- | ||||||||||
list of Variable, optional | ||||||||||
List of optimized variables, or None if no optimization was performed | ||||||||||
""" | ||||||||||
# List of useful operations : Inv, Pinv, Solve | ||||||||||
core_op = node.op.core_op | ||||||||||
if not ( | ||||||||||
isinstance(core_op, inv) | ||||||||||
or isinstance(core_op, pinv) | ||||||||||
or isinstance(core_op, solve) | ||||||||||
): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are other solve Ops that could represent an inversion, notable |
||||||||||
return None | ||||||||||
|
||||||||||
# Dealing with direct inverse Ops | ||||||||||
if isinstance(core_op, inv) or isinstance(core_op, pinv): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
# Dealing with diagonal matrix from eye_mul | ||||||||||
potential_mul_input = node.inputs[0] | ||||||||||
eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input) | ||||||||||
if eye_non_eye_inputs is not None: | ||||||||||
eye_input, non_eye_inputs = eye_non_eye_inputs | ||||||||||
|
||||||||||
# Dealing with only one other input | ||||||||||
if len(non_eye_inputs) != 1: | ||||||||||
return None | ||||||||||
|
||||||||||
useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Note sure what makes them "useful" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just the fact that im using them 😛 i'll update the names |
||||||||||
|
||||||||||
# For a matrix, we can first get the diagonal and then only use those | ||||||||||
if useful_non_eye.type.broadcastable[-2:] == (False, False): | ||||||||||
# For Matrix | ||||||||||
return [useful_eye * 1 / useful_non_eye.diagonal(axis1=-1, axis2=-2)] | ||||||||||
else: | ||||||||||
# For Scalar/Vector | ||||||||||
return [useful_eye * 1 / useful_non_eye] | ||||||||||
|
||||||||||
# Dealing with orthonormal matrix from SVD | ||||||||||
else: | ||||||||||
# Check if input to Inverse is coming from SVD | ||||||||||
input_to_inv = node.inputs[0] | ||||||||||
# Check if this input is coming from SVD with compute_uv = True | ||||||||||
if not ( | ||||||||||
isinstance(input_to_inv.owner.op, Blockwise) | ||||||||||
and isinstance(input_to_inv.owner.op.core_op, SVD) | ||||||||||
and input_to_inv.owner.op.core_op.compute_uv is True | ||||||||||
): | ||||||||||
return None | ||||||||||
|
||||||||||
# To make sure input is orthonormal, we have to check that its not S | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
if input_to_inv == input_to_inv.owner.outputs[1]: | ||||||||||
return None | ||||||||||
|
||||||||||
orthonormal_input = input_to_inv | ||||||||||
inverse = orthonormal_input.T | ||||||||||
return [inverse] | ||||||||||
Comment on lines
+551
to
+553
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
# Dealing with solve(x, eye) | ||||||||||
elif isinstance(core_op, solve): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most of the solve logic is going to be the same as the inv/pinv, except that you just have to check that B is eye, then extract the A argument. It probably won't be best to split it out this way |
||||||||||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -390,3 +390,40 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): | |
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals] | ||
|
||
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape", | ||
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], | ||
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], | ||
) | ||
def test_inv_diag_from_eye_mul(shape): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test for orthonormal matrix? |
||
# Initializing x based on scalar/vector/matrix | ||
x = pt.tensor("x", shape=shape) | ||
x_diag = pt.eye(7) * x | ||
# Calculating inverse using pt.linalg.inv | ||
x_inv = pt.linalg.inv(x_diag) | ||
|
||
# REWRITE TEST | ||
f_rewritten = function([x], x_inv, mode="FAST_RUN") | ||
nodes = f_rewritten.maker.fgraph.apply_nodes | ||
assert not any(isinstance(node.op, MatrixInverse) for node in nodes) | ||
assert not any(isinstance(node.op, MatrixPinv) for node in nodes) | ||
|
||
# NUMERIC VALUE TEST | ||
if len(shape) == 0: | ||
x_test = np.array(np.random.rand()).astype(config.floatX) | ||
elif len(shape) == 1: | ||
x_test = np.random.rand(*shape).astype(config.floatX) | ||
else: | ||
x_test = np.random.rand(*shape).astype(config.floatX) | ||
x_test_matrix = np.eye(7) * x_test | ||
inverse_matrix = np.linalg.inv(x_test_matrix) | ||
rewritten_inverse = f_rewritten(x_test) | ||
|
||
assert_allclose( | ||
inverse_matrix, | ||
rewritten_inverse, | ||
atol=1e-3 if config.floatX == "float32" else 1e-8, | ||
rtol=1e-3 if config.floatX == "float32" else 1e-8, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this duplicated from the other PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea. that one isnt merged yet so i had to add this function in this one too