From fc29a91b4f5b2017ccac6201fdb1813a8a837e92 Mon Sep 17 00:00:00 2001 From: Tanish Date: Sun, 15 Dec 2024 00:44:04 +0530 Subject: [PATCH 1/2] Added eig rewrite for diagonal matrix --- pytensor/tensor/rewriting/linalg.py | 68 +++++++++++++++++++++++ tests/tensor/rewriting/test_linalg.py | 79 +++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cd202fe3ed..c6a094a8a2 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -34,6 +34,7 @@ MatrixPinv, SLogDet, det, + eig, inv, kron, pinv, @@ -1013,3 +1014,70 @@ def slogdet_specialization(fgraph, node): k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() } return replacements + + +@register_canonicalize +@register_stabilize +@node_rewriter([eig]) +def rewrite_eig_diag(fgraph, node): + """ + This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the identity matrix. + + The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices + that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to + make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, + vector or a matrix. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + inputs = node.inputs[0] + + # Check for use of pt.diag first + if ( + inputs.owner + and isinstance(inputs.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(inputs.owner) + ): + eigval_rewritten = pt.diag(inputs) + eigvec_rewritten = pt.eye(inputs.shape[-1]) + return [eigval_rewritten, eigvec_rewritten] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + eye_input, non_eye_input = eye_input, non_eye_inputs[0] + # eigval_rewritten = pt.diag(non_eye_input) + eigvec_rewritten = eye_input + + # Checking if original x was scalar/vector/matrix + if non_eye_input.type.broadcastable[-2:] == (True, True): + # For scalar + eigval_rewritten = pt.full( + (eye_input.shape[0],), non_eye_input.squeeze(axis=(-1, -2)) + ) + elif non_eye_input.type.broadcastable[-2:] == (False, False): + # For Matrix + eigval_rewritten = pt.diag(non_eye_input) + else: + # For vector + eigval_rewritten = non_eye_input.squeeze() + + return [eigval_rewritten, eigvec_rewritten] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c9b9afff19..a9afa5a0e5 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -18,6 +18,7 @@ from pytensor.tensor.nlinalg import ( SVD, Det, + Eig, KroneckerProduct, MatrixInverse, MatrixPinv, @@ -996,3 +997,81 @@ def test_slogdet_specialization(): f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes assert not any(isinstance(node.op, SLogDet) for node in nodes) + + +@pytest.mark.parametrize( + "shape", + [(), (7,), (1, 7), (7, 1), (7, 7)], + ids=["scalar", "vector", "row_vec", "col_vec", "matrix"], +) +def test_eig_diag_from_eye_mul(shape): + # Initializing x based on scalar/vector/matrix + x = pt.tensor("x", shape=shape) + y = pt.eye(7) * x + + # Calculating eigval and eigvec using pt.linalg.eig + eigval, eigvec = pt.linalg.eig(y) + + # REWRITE TEST + f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + assert not any( + isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig) + for node in nodes + ) + + # NUMERIC VALUE TEST + if len(shape) == 0: + x_test = np.array(np.random.rand()).astype(config.floatX) + elif len(shape) == 1: + x_test = np.random.rand(*shape).astype(config.floatX) + else: + x_test = np.random.rand(*shape).astype(config.floatX) + + x_test_matrix = np.eye(7) * x_test + eigval, eigvec = np.linalg.eig(x_test_matrix) + rewritten_eigval, rewritten_eigvec = f_rewritten(x_test) + + assert_allclose( + eigval, + rewritten_eigval, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + eigvec, + rewritten_eigvec, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_eig_diag_from_diag(): + x = pt.tensor("x", shape=(None,)) + x_diag = pt.diag(x) + eigval, eigvec = pt.linalg.eig(x_diag) + + # REWRITE TEST + f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN") + f_rewritten.dprint() + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Eig) for node in nodes) + + # NUMERIC VALUE TEST + x_test = np.random.rand(7).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test + eigval, eigvec = np.linalg.eig(x_test_matrix) + rewritten_eigval, rewritten_eigvec = f_rewritten(x_test) + assert_allclose( + eigval, + rewritten_eigval, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + eigvec, + rewritten_eigvec, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From bdd167912395974f87c2e6d6f67c6c46bdebcaae Mon Sep 17 00:00:00 2001 From: Tanish Date: Wed, 18 Dec 2024 21:53:12 +0530 Subject: [PATCH 2/2] added rewrite for eig when input matrix is identity --- pytensor/tensor/rewriting/linalg.py | 36 ++++++++++++++++++++++++++- tests/tensor/rewriting/test_linalg.py | 30 ++++++++++++++++++++-- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index c6a094a8a2..da16acf1a4 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1016,12 +1016,46 @@ def slogdet_specialization(fgraph, node): return replacements +@register_canonicalize +@register_stabilize +@node_rewriter([eig]) +def rewrite_eig_eye(fgraph, node): + """ + This rewrite takes advantage of the fact that for any identity matrix, all the eigenvalues are 1 and the eigenvectors are the standard basis. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check whether input to Eig is Eye and the 1's are on main diagonal + potential_eye = node.inputs[0] + if not ( + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0 + ): + return None + + eigval_rewritten = pt.ones(potential_eye.shape[-1]) + eigvec_rewritten = pt.eye(potential_eye.shape[-1]) + + return [eigval_rewritten, eigvec_rewritten] + + @register_canonicalize @register_stabilize @node_rewriter([eig]) def rewrite_eig_diag(fgraph, node): """ - This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the identity matrix. + This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the standard basis. The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index a9afa5a0e5..fa9c5f84e6 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1047,14 +1047,40 @@ def test_eig_diag_from_eye_mul(shape): ) -def test_eig_diag_from_diag(): +def test_eig_eye(): + x = pt.eye(10) + eigval, eigvec = pt.linalg.eig(x) + + # REWRITE TEST + f_rewritten = function([], [eigval, eigvec], mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Eig) for node in nodes) + + # NUMERIC VALUE TEST + x_test = np.eye(10) + eigval, eigvec = np.linalg.eig(x_test) + rewritten_eigval, rewritten_eigvec = f_rewritten() + assert_allclose( + eigval, + rewritten_eigval, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + eigvec, + rewritten_eigvec, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_eig_diag(): x = pt.tensor("x", shape=(None,)) x_diag = pt.diag(x) eigval, eigvec = pt.linalg.eig(x_diag) # REWRITE TEST f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN") - f_rewritten.dprint() nodes = f_rewritten.maker.fgraph.apply_nodes assert not any(isinstance(node.op, Eig) for node in nodes)