diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 30d9084449..027d907a63 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -8,10 +8,11 @@ copy_stack_trace, node_rewriter, ) -from pytensor.tensor.basic import TensorVariable, diagonal +from pytensor.scalar.basic import Mul +from pytensor.tensor.basic import Eye, TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import ( SVD, @@ -438,3 +439,119 @@ def svd_uv_merge(fgraph, node): or len(fgraph.clients[cl.outputs[2]]) > 0 ): return [cl.outputs[1]] + + +def _find_diag_from_eye_mul(potential_mul_input): + # Check if the op is Elemwise and mul + if not ( + potential_mul_input.owner is not None + and isinstance(potential_mul_input.owner.op, Elemwise) + and isinstance(potential_mul_input.owner.op.scalar_op, Mul) + ): + return None + + # Find whether any of the inputs to mul is Eye + inputs_to_mul = potential_mul_input.owner.inputs + eye_input = [ + mul_input + for mul_input in inputs_to_mul + if mul_input.owner and isinstance(mul_input.owner.op, Eye) + ] + # Check if 1's are being put on the main diagonal only (k = 1) + k = getattr(eye_input[0].owner.inputs[-1], "data", 0).item() + if k != 0: + return None + + # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite + if eye_input and eye_input[0].broadcastable[-2:] != (False, False): + return None + + # Get all non Eye inputs (scalars/matrices/vectors) + non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) + return eye_input, non_eye_inputs + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_for_diag_and_orthonormal(fgraph, node): + """ + This rewrite covers a few cases which take advantage of the fact that : + 1. for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. + 2. for an orthonormal matrix, the inverse is simply the transpose. + + For simplicity, this function deals with the following cases : + 1. for a diagonal matrix : + i) arising from the multiplicaton of eye with a scalar/vector/matrix + ii) arising from pt.diag of a vector + iii) solve(x, eye) directly returns inv(x) + 2. for an orthonormal matrix : + i) arising from pt.linalg.svd decomposition (U, Vh) + ii) arising from pt.linalg.qr + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # List of useful operations : Inv, Pinv, Solve + core_op = node.op.core_op + if not ( + isinstance(core_op, inv) + or isinstance(core_op, pinv) + or isinstance(core_op, solve) + ): + return None + + # Dealing with direct inverse Ops + if isinstance(core_op, inv) or isinstance(core_op, pinv): + # Dealing with diagonal matrix from eye_mul + potential_mul_input = node.inputs[0] + eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input) + if eye_non_eye_inputs is not None: + eye_input, non_eye_inputs = eye_non_eye_inputs + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] + + # For a matrix, we can first get the diagonal and then only use those + if useful_non_eye.type.broadcastable[-2:] == (False, False): + # For Matrix + return [useful_eye * 1 / useful_non_eye.diagonal(axis1=-1, axis2=-2)] + else: + # For Scalar/Vector + return [useful_eye * 1 / useful_non_eye] + + # Dealing with orthonormal matrix from SVD + else: + # Check if input to Inverse is coming from SVD + input_to_inv = node.inputs[0] + # Check if this input is coming from SVD with compute_uv = True + if not ( + isinstance(input_to_inv.owner.op, Blockwise) + and isinstance(input_to_inv.owner.op.core_op, SVD) + and input_to_inv.owner.op.core_op.compute_uv is True + ): + return None + + # To make sure input is orthonormal, we have to check that its not S + if input_to_inv == input_to_inv.owner.outputs[1]: + return None + + orthonormal_input = input_to_inv + inverse = orthonormal_input.T + return [inverse] + + # Dealing with solve(x, eye) + elif isinstance(core_op, solve): + pass diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 523742e356..a83e492c0a 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -456,3 +456,40 @@ def test_svd_uv_merge(): assert node.op.compute_uv svd_counter += 1 assert svd_counter == 1 + + +@pytest.mark.parametrize( + "shape", + [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], + ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], +) +def test_inv_diag_from_eye_mul(shape): + # Initializing x based on scalar/vector/matrix + x = pt.tensor("x", shape=shape) + x_diag = pt.eye(7) * x + # Calculating inverse using pt.linalg.inv + x_inv = pt.linalg.inv(x_diag) + + # REWRITE TEST + f_rewritten = function([x], x_inv, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, MatrixInverse) for node in nodes) + assert not any(isinstance(node.op, MatrixPinv) for node in nodes) + + # NUMERIC VALUE TEST + if len(shape) == 0: + x_test = np.array(np.random.rand()).astype(config.floatX) + elif len(shape) == 1: + x_test = np.random.rand(*shape).astype(config.floatX) + else: + x_test = np.random.rand(*shape).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test + inverse_matrix = np.linalg.inv(x_test_matrix) + rewritten_inverse = f_rewritten(x_test) + + assert_allclose( + inverse_matrix, + rewritten_inverse, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + )