From 82b7e9838db76f295ac0405cc9a9ada0a50fe2a4 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 2 Jun 2024 18:02:01 +0530 Subject: [PATCH 01/31] Added det-diag rewrite --- pytensor/tensor/rewriting/linalg.py | 58 ++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cdb1e59101..73e7b742a1 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -5,10 +5,10 @@ from pytensor import Variable from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter -from pytensor.tensor.basic import TensorVariable, diagonal +from pytensor.tensor.basic import TensorVariable, diagonal, Eye, Mul from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import ( KroneckerProduct, @@ -18,6 +18,7 @@ inv, kron, pinv, + Det, ) from pytensor.tensor.rewriting.basic import ( register_canonicalize, @@ -377,3 +378,56 @@ def local_lift_through_linalg( return [block_diag(*inner_matrices)] else: raise NotImplementedError # pragma: no cover + + +# Det Diag Rewrite +def det_diag_rewrite(node: Apply): + # Find if we have Blockwise Op + if not( + isinstance(node.op, Blockwise) + and isinstance (node.op.core_op, Det) + ): + return None + + node_input_op = node.inputs[0] + # Check if the op is Elemwise and mul + if not( + isinstance(node_input_op.owner.op, Elemwise) + and isinstance(node_input_op.owner.op.scalar_op, Mul) + ): + return None + + # Find whether any of the inputs to mul is Eye + inputs_to_mul = node_input_op.owner.inputs + eye_check = False + eye_input = [] + for mul_input in inputs_to_mul: + if isinstance(mul_input.owner.op, Eye): + eye_check = True + eye_input.append(mul_input) + break # This is so that we only find the first one? Can be changed + if not ( + eye_check + ): + return None + + # Get all non Eye inputs (scalars/matrices/vectors) + non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) + + # Dealing with only one other input + if (len(non_eye_inputs) >= 2): + raise NotImplementedError + + # Checking if original x was matrix + if not non_eye_inputs[0].owner: + # It was a matrix + det_val = pt.diagonal(non_eye_inputs[0]).prod() + else: + # Check for scalar (ndim = 0) or vector (ndim = 1) + sca_vec_input = non_eye_inputs[0].owner.inputs[0] + if (sca_vec_input.ndim == 0): + det_val = sca_vec_input**eye_input[0].shape[0] + else: + det_val = sca_vec_input.prod() + + return [det_val] \ No newline at end of file From e661ba416c33e60da9a7efcc9b598abbb0bfb737 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 2 Jun 2024 18:17:08 +0530 Subject: [PATCH 02/31] fixed pt.diagonal error --- pytensor/tensor/rewriting/linalg.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 73e7b742a1..5b5b318a7c 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -5,12 +5,13 @@ from pytensor import Variable from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter -from pytensor.tensor.basic import TensorVariable, diagonal, Eye, Mul +from pytensor.tensor.basic import Eye, Mul, TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import ( + Det, KroneckerProduct, MatrixInverse, MatrixPinv, @@ -18,7 +19,6 @@ inv, kron, pinv, - Det, ) from pytensor.tensor.rewriting.basic import ( register_canonicalize, @@ -383,15 +383,12 @@ def local_lift_through_linalg( # Det Diag Rewrite def det_diag_rewrite(node: Apply): # Find if we have Blockwise Op - if not( - isinstance(node.op, Blockwise) - and isinstance (node.op.core_op, Det) - ): + if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Det)): return None node_input_op = node.inputs[0] # Check if the op is Elemwise and mul - if not( + if not ( isinstance(node_input_op.owner.op, Elemwise) and isinstance(node_input_op.owner.op.scalar_op, Mul) ): @@ -405,29 +402,27 @@ def det_diag_rewrite(node: Apply): if isinstance(mul_input.owner.op, Eye): eye_check = True eye_input.append(mul_input) - break # This is so that we only find the first one? Can be changed - if not ( - eye_check - ): + break # This is so that we only find the first one? Can be changed + if not (eye_check): return None # Get all non Eye inputs (scalars/matrices/vectors) non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) - + # Dealing with only one other input - if (len(non_eye_inputs) >= 2): + if len(non_eye_inputs) >= 2: raise NotImplementedError # Checking if original x was matrix if not non_eye_inputs[0].owner: # It was a matrix - det_val = pt.diagonal(non_eye_inputs[0]).prod() + det_val = non_eye_inputs[0].diagonal().prod() else: # Check for scalar (ndim = 0) or vector (ndim = 1) sca_vec_input = non_eye_inputs[0].owner.inputs[0] - if (sca_vec_input.ndim == 0): - det_val = sca_vec_input**eye_input[0].shape[0] + if sca_vec_input.ndim == 0: + det_val = sca_vec_input ** eye_input[0].shape[0] else: det_val = sca_vec_input.prod() - return [det_val] \ No newline at end of file + return [det_val] From a1823e98b6ba69c203b9dbd7db05a26729450527 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 5 Jun 2024 16:43:52 +0530 Subject: [PATCH 03/31] Added test for rewrite --- pytensor/tensor/rewriting/linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 5b5b318a7c..7979f2201d 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -5,7 +5,8 @@ from pytensor import Variable from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter -from pytensor.tensor.basic import Eye, Mul, TensorVariable, diagonal +from pytensor.scalar.basic import Mul +from pytensor.tensor.basic import Eye, TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise From ac21b9e0b1a5eb8d5328d7b6ddd3838c7513a632 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 5 Jun 2024 16:44:31 +0530 Subject: [PATCH 04/31] Added test for rewrite --- tests/tensor/rewriting/test_linalg.py | 33 ++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 1e9d6194db..dc75f38217 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -21,7 +21,7 @@ MatrixPinv, matrix_inverse, ) -from pytensor.tensor.rewriting.linalg import inv_as_solve +from pytensor.tensor.rewriting.linalg import det_diag_rewrite, inv_as_solve from pytensor.tensor.slinalg import ( BlockDiagonal, Cholesky, @@ -390,3 +390,34 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals] np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) + + +@pytest.mark.paramterize( + "i,x_test", [(0, 4), (1, [4, 3, 2, 5]), (2, [[3, 7, 4], [2, 5, 6], [6, 2, 3]])] +) +def test_det_diag_rewrite(i, x_test): + # Initialising x based on scalar/vector/matrix + if i == 0: + x = pt.dscalar("x") + y = pt.eye(7) * x + elif i == 1: + x = pt.dvector("x") + y = pt.eye(x.shape[0]) * x + elif i == 2: + x = pt.dmatrix("x") + y = pt.eye(x.shape[0]) * x + # Caluclating determinant value using pt.linalg.det + z_det = pt.linalg.det(y) + f_det = function([x], z_det) + det_val = f_det(x_test) + # Applying the det diag rewrite + [rewritten] = det_diag_rewrite(z_det.owner) + f_rewritten = function([x], rewritten) + rewritten_val = f_rewritten(x_test) + + assert_allclose( + det_val, + rewritten_val, + atol=1e-4 if config.floatX == "float32" else 1e-8, + rtol=1e-4 if config.floatX == "float32" else 1e-8, + ) From 01e47d7903e4324bc205013e5e9bf8fb1879efb0 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 5 Jun 2024 17:02:26 +0530 Subject: [PATCH 05/31] fixed test --- tests/tensor/rewriting/test_linalg.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index dc75f38217..1a1ca7e430 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -392,11 +392,16 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) -@pytest.mark.paramterize( +@pytest.fixture +def x_test_cases(): + return [(0, 4), (1, [4, 3, 2, 5]), (2, [[3, 7, 4], [2, 5, 6], [6, 2, 3]])] + + +@pytest.mark.parametrize( "i,x_test", [(0, 4), (1, [4, 3, 2, 5]), (2, [[3, 7, 4], [2, 5, 6], [6, 2, 3]])] ) def test_det_diag_rewrite(i, x_test): - # Initialising x based on scalar/vector/matrix + # Initializing x based on scalar/vector/matrix if i == 0: x = pt.dscalar("x") y = pt.eye(7) * x @@ -406,7 +411,7 @@ def test_det_diag_rewrite(i, x_test): elif i == 2: x = pt.dmatrix("x") y = pt.eye(x.shape[0]) * x - # Caluclating determinant value using pt.linalg.det + # Calculating determinant value using pt.linalg.det z_det = pt.linalg.det(y) f_det = function([x], z_det) det_val = f_det(x_test) From 89f209ff8f06ed33b6136b69cf23a1e963e79ed0 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Thu, 6 Jun 2024 14:49:04 +0530 Subject: [PATCH 06/31] added check for verifying rewrite --- pytensor/tensor/rewriting/linalg.py | 5 +++- tests/tensor/rewriting/test_linalg.py | 38 +++++++++++++-------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 7979f2201d..15ec04cc3b 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -382,7 +382,10 @@ def local_lift_through_linalg( # Det Diag Rewrite -def det_diag_rewrite(node: Apply): +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def det_diag_rewrite(fgraph, node): # Find if we have Blockwise Op if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Det)): return None diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 1a1ca7e430..d957412f02 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -21,7 +21,7 @@ MatrixPinv, matrix_inverse, ) -from pytensor.tensor.rewriting.linalg import det_diag_rewrite, inv_as_solve +from pytensor.tensor.rewriting.linalg import inv_as_solve from pytensor.tensor.slinalg import ( BlockDiagonal, Cholesky, @@ -392,32 +392,30 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8) -@pytest.fixture -def x_test_cases(): - return [(0, 4), (1, [4, 3, 2, 5]), (2, [[3, 7, 4], [2, 5, 6], [6, 2, 3]])] - - @pytest.mark.parametrize( - "i,x_test", [(0, 4), (1, [4, 3, 2, 5]), (2, [[3, 7, 4], [2, 5, 6], [6, 2, 3]])] + "shape", [(), (7,), (7, 7)], ids=["scalar", "vector", "matrix"] ) -def test_det_diag_rewrite(i, x_test): +def test_det_diag_rewrite(shape): # Initializing x based on scalar/vector/matrix - if i == 0: - x = pt.dscalar("x") - y = pt.eye(7) * x - elif i == 1: - x = pt.dvector("x") - y = pt.eye(x.shape[0]) * x - elif i == 2: - x = pt.dmatrix("x") - y = pt.eye(x.shape[0]) * x + x = pt.tensor("x", shape=shape) + y = pt.eye(7) * x # Calculating determinant value using pt.linalg.det z_det = pt.linalg.det(y) + + # REWRITE TEST + f_rewritten = function([x], z_det, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Det) for node in nodes) + + # NUMERIC VALUE TEST f_det = function([x], z_det) + if len(shape) == 0: + x_test = np.random.rand() + elif len(shape) == 1: + x_test = np.random.rand(7) + else: + x_test = np.random.rand(7, 7) det_val = f_det(x_test) - # Applying the det diag rewrite - [rewritten] = det_diag_rewrite(z_det.owner) - f_rewritten = function([x], rewritten) rewritten_val = f_rewritten(x_test) assert_allclose( From 0bd2e267a3ea51b9d2046fdb0b54e38877081e24 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Thu, 6 Jun 2024 19:02:44 +0530 Subject: [PATCH 07/31] fixed other failing test --- pytensor/tensor/rewriting/linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 15ec04cc3b..cd1e26c7be 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -393,7 +393,8 @@ def det_diag_rewrite(fgraph, node): node_input_op = node.inputs[0] # Check if the op is Elemwise and mul if not ( - isinstance(node_input_op.owner.op, Elemwise) + node_input_op.owner is not None + and isinstance(node_input_op.owner.op, Elemwise) and isinstance(node_input_op.owner.op.scalar_op, Mul) ): return None From 170a860884f6807faf1931e2ce3bdb27ef109341 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Fri, 7 Jun 2024 11:53:51 +0530 Subject: [PATCH 08/31] added docstring --- pytensor/tensor/rewriting/linalg.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cd1e26c7be..a6ad1986a6 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -385,7 +385,27 @@ def local_lift_through_linalg( @register_canonicalize @register_stabilize @node_rewriter([det]) -def det_diag_rewrite(fgraph, node): +def det_diag_rewrite(fgraph: FunctionGraph, node: Apply) -> list[Variable] or None: + """ + Rewrites the determinant of a diagonal matrix into a simpler computation by using information about how the diagonal matrix arises. + + This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. + Scalar : If we multiply a scalar with an identity matrix, the determinant is the the number of diagonal elements times the scalar value + Vector : If we multiply a vector with an identity matrix, the determinant is the product of the elements of the vector (which lie along the diagonal of the final matrix) + Matrix : If we multiply a matrix with another identity matrix, the determinant is the product of diagonal elements of the original matrix + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ # Find if we have Blockwise Op if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Det)): return None From 0191b977e17bf0f71c4ef19e242895629bc537b9 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Fri, 7 Jun 2024 15:58:33 +0530 Subject: [PATCH 09/31] updated docstring --- pytensor/tensor/rewriting/linalg.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a6ad1986a6..28db1768d7 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -387,12 +387,9 @@ def local_lift_through_linalg( @node_rewriter([det]) def det_diag_rewrite(fgraph: FunctionGraph, node: Apply) -> list[Variable] or None: """ - Rewrites the determinant of a diagonal matrix into a simpler computation by using information about how the diagonal matrix arises. + This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. - This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. - Scalar : If we multiply a scalar with an identity matrix, the determinant is the the number of diagonal elements times the scalar value - Vector : If we multiply a vector with an identity matrix, the determinant is the product of the elements of the vector (which lie along the diagonal of the final matrix) - Matrix : If we multiply a matrix with another identity matrix, the determinant is the product of diagonal elements of the original matrix + The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix. Parameters ---------- From 7f7c8034c5720ab14413ceb57a2f7dbca00fe687 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Fri, 7 Jun 2024 18:20:24 +0530 Subject: [PATCH 10/31] fixed mypy error --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 28db1768d7..3691423906 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -385,7 +385,7 @@ def local_lift_through_linalg( @register_canonicalize @register_stabilize @node_rewriter([det]) -def det_diag_rewrite(fgraph: FunctionGraph, node: Apply) -> list[Variable] or None: +def det_diag_rewrite(fgraph, node): """ This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. From 8d7f9f733e7674dfc211b8ebe777f531079bd995 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 9 Jun 2024 18:25:59 +0530 Subject: [PATCH 11/31] added det_diag_from_diag and test --- pytensor/tensor/rewriting/linalg.py | 39 ++++++++++++++++++++++----- tests/tensor/rewriting/test_linalg.py | 36 ++++++++++++++++++++----- 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 3691423906..a5f64f7f69 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,11 +2,15 @@ from collections.abc import Callable from typing import cast -from pytensor import Variable +from pytensor import Variable, config from pytensor.graph import Apply, FunctionGraph -from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter +from pytensor.graph.rewriting.basic import ( + PatternNodeRewriter, + copy_stack_trace, + node_rewriter, +) from pytensor.scalar.basic import Mul -from pytensor.tensor.basic import Eye, TensorVariable, diagonal +from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise @@ -36,6 +40,7 @@ solve, solve_triangular, ) +from pytensor.tensor.subtensor import advanced_set_subtensor logger = logging.getLogger(__name__) @@ -381,11 +386,11 @@ def local_lift_through_linalg( raise NotImplementedError # pragma: no cover -# Det Diag Rewrite +# Det Diag Rewrites @register_canonicalize @register_stabilize @node_rewriter([det]) -def det_diag_rewrite(fgraph, node): +def rewrite_det_diag_from_eye_mul(fgraph, node): """ This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. @@ -443,8 +448,30 @@ def det_diag_rewrite(fgraph, node): # Check for scalar (ndim = 0) or vector (ndim = 1) sca_vec_input = non_eye_inputs[0].owner.inputs[0] if sca_vec_input.ndim == 0: - det_val = sca_vec_input ** eye_input[0].shape[0] + det_val = sca_vec_input ** (eye_input[0].shape[0]).astype(config.floatX) else: det_val = sca_vec_input.prod() return [det_val] + + +# Det diag from diag +arange = ARange("int64") +det_diag_from_diag = PatternNodeRewriter( + ( + det, + ( + advanced_set_subtensor, + (alloc, 0, "sh1", "sh2"), + "x", + (arange, 0, "stop", 1), + (arange, 0, "stop", 1), + ), + ), + (prod, "x"), + name="determinant_of_diagonal", + allow_multiple_clients=True, +) +register_canonicalize(det_diag_from_diag) +register_stabilize(det_diag_from_diag) +register_specialize(det_diag_from_diag) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index d957412f02..5ce7ccd907 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -395,7 +395,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): @pytest.mark.parametrize( "shape", [(), (7,), (7, 7)], ids=["scalar", "vector", "matrix"] ) -def test_det_diag_rewrite(shape): +def test_det_diag_from_eye_mul(shape): # Initializing x based on scalar/vector/matrix x = pt.tensor("x", shape=shape) y = pt.eye(7) * x @@ -410,17 +410,41 @@ def test_det_diag_rewrite(shape): # NUMERIC VALUE TEST f_det = function([x], z_det) if len(shape) == 0: - x_test = np.random.rand() + x_test = np.array(np.random.rand()).astype(config.floatX) elif len(shape) == 1: - x_test = np.random.rand(7) + x_test = np.random.rand(7).astype(config.floatX) else: - x_test = np.random.rand(7, 7) + x_test = np.random.rand(7, 7).astype(config.floatX) det_val = f_det(x_test) rewritten_val = f_rewritten(x_test) assert_allclose( det_val, rewritten_val, - atol=1e-4 if config.floatX == "float32" else 1e-8, - rtol=1e-4 if config.floatX == "float32" else 1e-8, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_det_diag_from_diag(): + x = pt.tensor("x", shape=(None,)) + x_diag = pt.diag(x) + y = pt.linalg.det(x_diag) + + # REWRITE TEST + f_rewritten = function([x], y, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Det) for node in nodes) + + # NUMERIC VALUE TEST + f_det = function([x], y) + x_test = np.random.rand(7).astype(config.floatX) + det_val = f_det(x_test) + rewritten_val = f_rewritten(x_test) + + assert_allclose( + det_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, ) From 9ad0606eadb4279eafb7e216a0daaff10b7047d4 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 9 Jun 2024 18:26:45 +0530 Subject: [PATCH 12/31] fixed node rewriter name --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a5f64f7f69..865a89b928 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -469,7 +469,7 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): ), ), (prod, "x"), - name="determinant_of_diagonal", + name="det_diag_from_diag", allow_multiple_clients=True, ) register_canonicalize(det_diag_from_diag) From c8972963ea858eec60c357e2e170047e014efd66 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Tue, 11 Jun 2024 14:17:02 +0530 Subject: [PATCH 13/31] added row/col tests --- pytensor/tensor/rewriting/linalg.py | 7 +------ tests/tensor/rewriting/test_linalg.py | 8 +++++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 865a89b928..f7c17f142d 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -16,7 +16,6 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod from pytensor.tensor.nlinalg import ( - Det, KroneckerProduct, MatrixInverse, MatrixPinv, @@ -408,10 +407,6 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - # Find if we have Blockwise Op - if not (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Det)): - return None - node_input_op = node.inputs[0] # Check if the op is Elemwise and mul if not ( @@ -438,7 +433,7 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): # Dealing with only one other input if len(non_eye_inputs) >= 2: - raise NotImplementedError + return None # Checking if original x was matrix if not non_eye_inputs[0].owner: diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 5ce7ccd907..58c3646e43 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -393,7 +393,9 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): @pytest.mark.parametrize( - "shape", [(), (7,), (7, 7)], ids=["scalar", "vector", "matrix"] + "shape", + [(), (7,), (1, 7), (7, 1), (7, 7)], + ids=["scalar", "vector", "row_vec", "col_vec", "matrix"], ) def test_det_diag_from_eye_mul(shape): # Initializing x based on scalar/vector/matrix @@ -412,9 +414,9 @@ def test_det_diag_from_eye_mul(shape): if len(shape) == 0: x_test = np.array(np.random.rand()).astype(config.floatX) elif len(shape) == 1: - x_test = np.random.rand(7).astype(config.floatX) + x_test = np.random.rand(*shape).astype(config.floatX) else: - x_test = np.random.rand(7, 7).astype(config.floatX) + x_test = np.random.rand(*shape).astype(config.floatX) det_val = f_det(x_test) rewritten_val = f_rewritten(x_test) From 6c3cdaea4d22913a155550b19891bfb9ad4da499 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Tue, 11 Jun 2024 14:33:53 +0530 Subject: [PATCH 14/31] updated check for eye --- pytensor/tensor/rewriting/linalg.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index f7c17f142d..b031a4fbe3 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -418,14 +418,12 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): # Find whether any of the inputs to mul is Eye inputs_to_mul = node_input_op.owner.inputs - eye_check = False - eye_input = [] - for mul_input in inputs_to_mul: - if isinstance(mul_input.owner.op, Eye): - eye_check = True - eye_input.append(mul_input) - break # This is so that we only find the first one? Can be changed - if not (eye_check): + eye_input = [ + mul_input + for mul_input in inputs_to_mul + if mul_input.owner and isinstance(mul_input.owner.op, Eye) + ] + if not (eye_input[0]): return None # Get all non Eye inputs (scalars/matrices/vectors) From 6b58cfb1e116b13a5627c6cb12dfdb3c1c5237dd Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sat, 15 Jun 2024 19:23:00 +0530 Subject: [PATCH 15/31] updated rewrite and tests --- pytensor/tensor/rewriting/linalg.py | 34 ++++++++++++++------------- tests/tensor/rewriting/test_linalg.py | 8 +++---- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index b031a4fbe3..bb0d640351 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -407,23 +407,23 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - node_input_op = node.inputs[0] + node_input = node.inputs[0] # Check if the op is Elemwise and mul if not ( - node_input_op.owner is not None - and isinstance(node_input_op.owner.op, Elemwise) - and isinstance(node_input_op.owner.op.scalar_op, Mul) + node_input.owner is not None + and isinstance(node_input.owner.op, Elemwise) + and isinstance(node_input.owner.op.scalar_op, Mul) ): return None # Find whether any of the inputs to mul is Eye - inputs_to_mul = node_input_op.owner.inputs + inputs_to_mul = node_input.owner.inputs eye_input = [ mul_input for mul_input in inputs_to_mul if mul_input.owner and isinstance(mul_input.owner.op, Eye) ] - if not (eye_input[0]): + if not (eye_input): return None # Get all non Eye inputs (scalars/matrices/vectors) @@ -433,17 +433,19 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): if len(non_eye_inputs) >= 2: return None - # Checking if original x was matrix - if not non_eye_inputs[0].owner: - # It was a matrix - det_val = non_eye_inputs[0].diagonal().prod() + # Checking if original x was scalar/vector/matrix + + if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): + # For scalar + det_val = non_eye_inputs[0].owner.inputs[0] ** (eye_input[0].shape[0]).astype( + config.floatX + ) + elif non_eye_inputs[0].type.broadcastable[-2:] == (False, False): + # For Matrix + det_val = non_eye_inputs[0].diagonal().prod(axis=-1) else: - # Check for scalar (ndim = 0) or vector (ndim = 1) - sca_vec_input = non_eye_inputs[0].owner.inputs[0] - if sca_vec_input.ndim == 0: - det_val = sca_vec_input ** (eye_input[0].shape[0]).astype(config.floatX) - else: - det_val = sca_vec_input.prod() + # For vector + det_val = non_eye_inputs[0].prod(axis=(-1, -2)) return [det_val] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 58c3646e43..b208eac916 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -394,8 +394,8 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g): @pytest.mark.parametrize( "shape", - [(), (7,), (1, 7), (7, 1), (7, 7)], - ids=["scalar", "vector", "row_vec", "col_vec", "matrix"], + [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], + ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], ) def test_det_diag_from_eye_mul(shape): # Initializing x based on scalar/vector/matrix @@ -410,14 +410,14 @@ def test_det_diag_from_eye_mul(shape): assert not any(isinstance(node.op, Det) for node in nodes) # NUMERIC VALUE TEST - f_det = function([x], z_det) if len(shape) == 0: x_test = np.array(np.random.rand()).astype(config.floatX) elif len(shape) == 1: x_test = np.random.rand(*shape).astype(config.floatX) else: x_test = np.random.rand(*shape).astype(config.floatX) - det_val = f_det(x_test) + x_test_matrix = np.eye(7) * x_test + det_val = np.linalg.det(x_test_matrix) rewritten_val = f_rewritten(x_test) assert_allclose( From 4316e75486691430164752b70d48922ec2b7944e Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Fri, 21 Jun 2024 20:38:44 +0530 Subject: [PATCH 16/31] added check for eye_input and new test for cases where not to apply rewrite --- pytensor/tensor/rewriting/linalg.py | 16 ++++++++----- tests/tensor/rewriting/test_linalg.py | 34 +++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index bb0d640351..71d2f57170 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -423,7 +423,7 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): for mul_input in inputs_to_mul if mul_input.owner and isinstance(mul_input.owner.op, Eye) ] - if not (eye_input): + if eye_input and eye_input[0].broadcastable[-2:] != (False, False): return None # Get all non Eye inputs (scalars/matrices/vectors) @@ -433,16 +433,20 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): if len(non_eye_inputs) >= 2: return None - # Checking if original x was scalar/vector/matrix + # Rewrite is only applied if all the shapes are known + # if (non_eye_inputs[0].type.shape[-2:] == (None, None) or eye_input[0].type.shape[-2:] == (None, None)): + # return None + # Otherwise, cases such as recantangle eye (pt.eye(7,5)) or degenerate eye (pt.eye(1)) will also be rewritten incorrectly. + # Checking if original x was scalar/vector/matrix if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): # For scalar - det_val = non_eye_inputs[0].owner.inputs[0] ** (eye_input[0].shape[0]).astype( - config.floatX - ) + det_val = non_eye_inputs[0].squeeze(axis=(-1, -2)) ** ( + eye_input[0].shape[0] + ).astype(config.floatX) elif non_eye_inputs[0].type.broadcastable[-2:] == (False, False): # For Matrix - det_val = non_eye_inputs[0].diagonal().prod(axis=-1) + det_val = non_eye_inputs[0].diagonal(axis1=-1, axis2=-2).prod(axis=-1) else: # For vector det_val = non_eye_inputs[0].prod(axis=(-1, -2)) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index b208eac916..c3e8b9c6d9 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -439,9 +439,9 @@ def test_det_diag_from_diag(): assert not any(isinstance(node.op, Det) for node in nodes) # NUMERIC VALUE TEST - f_det = function([x], y) x_test = np.random.rand(7).astype(config.floatX) - det_val = f_det(x_test) + x_test_matrix = np.eye(7) * x_test + det_val = np.linalg.det(x_test_matrix) rewritten_val = f_rewritten(x_test) assert_allclose( @@ -450,3 +450,33 @@ def test_det_diag_from_diag(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +# degenrate eye : pt.eye(1) +# rectangle : non square eye +@pytest.mark.parametrize( + "shape", [(1, 1), (7, 5)], ids=["degenerate_eye", "rectangle_eye"] +) +def test_dont_apply_det_diag_rewrite(shape): + x = pt.matrix("x") + x_diag = pt.eye(*shape) * x + y = pt.linalg.det(x_diag) + f_rewritten = function([x], y, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + assert any(isinstance(node.op, Det) for node in nodes) + + # NUMERIC VALUE TEST + shape_new = (2, 2) if shape == (1, 1) else shape + x_test = np.random.normal(size=shape_new) + x_test_matrix = np.eye(*shape) * x_test + + if shape != (7, 5): + det_val = np.linalg.det(x_test_matrix) + rewritten_val = f_rewritten(x_test) + assert_allclose( + det_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From 6743e0aa782a75a55005e7619ced438bfe3d6eae Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sat, 22 Jun 2024 09:59:22 +0530 Subject: [PATCH 17/31] does not apply rewrite to specific cases --- pytensor/tensor/rewriting/linalg.py | 8 +++++--- tests/tensor/rewriting/test_linalg.py | 10 +++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 71d2f57170..ebcbb52328 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -434,9 +434,11 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): return None # Rewrite is only applied if all the shapes are known - # if (non_eye_inputs[0].type.shape[-2:] == (None, None) or eye_input[0].type.shape[-2:] == (None, None)): - # return None - # Otherwise, cases such as recantangle eye (pt.eye(7,5)) or degenerate eye (pt.eye(1)) will also be rewritten incorrectly. + if non_eye_inputs[0].type.shape[-2:] == (None, None) or eye_input[0].type.shape[ + -2: + ] == (None, None): + return None + # This ensures that the rewrite is NOT applied to cases of degenerate eye (pt.eye(1)) and rectangle eye (pt.eye(7,5)) # Checking if original x was scalar/vector/matrix if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c3e8b9c6d9..176a6215a5 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -452,8 +452,6 @@ def test_det_diag_from_diag(): ) -# degenrate eye : pt.eye(1) -# rectangle : non square eye @pytest.mark.parametrize( "shape", [(1, 1), (7, 5)], ids=["degenerate_eye", "rectangle_eye"] ) @@ -466,12 +464,10 @@ def test_dont_apply_det_diag_rewrite(shape): assert any(isinstance(node.op, Det) for node in nodes) - # NUMERIC VALUE TEST - shape_new = (2, 2) if shape == (1, 1) else shape - x_test = np.random.normal(size=shape_new) - x_test_matrix = np.eye(*shape) * x_test - + # NUMERIC VALUE TEST (only in case of (1,1)) if shape != (7, 5): + x_test = np.random.normal(size=(3, 3)) + x_test_matrix = np.eye(*shape) * x_test det_val = np.linalg.det(x_test_matrix) rewritten_val = f_rewritten(x_test) assert_allclose( From ebca33959304c13b40932da2ce5c58b72e95cd8d Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sat, 22 Jun 2024 14:35:45 +0530 Subject: [PATCH 18/31] typecasted test variable --- tests/tensor/rewriting/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 176a6215a5..fc9e17304d 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -466,7 +466,7 @@ def test_dont_apply_det_diag_rewrite(shape): # NUMERIC VALUE TEST (only in case of (1,1)) if shape != (7, 5): - x_test = np.random.normal(size=(3, 3)) + x_test = np.random.normal(size=(3, 3)).astype(config.floatX) x_test_matrix = np.eye(*shape) * x_test det_val = np.linalg.det(x_test_matrix) rewritten_val = f_rewritten(x_test) From 58212b68397c4d8af4f51bec6e8939c22e884772 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Sun, 23 Jun 2024 12:41:52 +0530 Subject: [PATCH 19/31] typecast variables --- pytensor/tensor/rewriting/linalg.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index ebcbb52328..c60a319723 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -9,6 +9,7 @@ copy_stack_trace, node_rewriter, ) +from pytensor.raise_op import Assert from pytensor.scalar.basic import Mul from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal from pytensor.tensor.blas import Dot22 @@ -43,6 +44,9 @@ logger = logging.getLogger(__name__) +assert_square_matrix_for_det = Assert( + "Last 2 dimensions of the input tensor to det are not equal and thus, it is not square!" +) def is_matrix_transpose(x: TensorVariable) -> bool: @@ -385,7 +389,6 @@ def local_lift_through_linalg( raise NotImplementedError # pragma: no cover -# Det Diag Rewrites @register_canonicalize @register_stabilize @node_rewriter([det]) @@ -423,6 +426,7 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): for mul_input in inputs_to_mul if mul_input.owner and isinstance(mul_input.owner.op, Eye) ] + # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite if eye_input and eye_input[0].broadcastable[-2:] != (False, False): return None @@ -438,7 +442,6 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): -2: ] == (None, None): return None - # This ensures that the rewrite is NOT applied to cases of degenerate eye (pt.eye(1)) and rectangle eye (pt.eye(7,5)) # Checking if original x was scalar/vector/matrix if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): @@ -456,7 +459,6 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): return [det_val] -# Det diag from diag arange = ARange("int64") det_diag_from_diag = PatternNodeRewriter( ( From 3ef186d6ccb28786ad24ce2137f0014c4587d713 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 24 Jun 2024 11:35:50 +0530 Subject: [PATCH 20/31] removed shape known check; fails for rectangle eye --- pytensor/tensor/rewriting/linalg.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index c60a319723..dc0465bf1d 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -9,7 +9,6 @@ copy_stack_trace, node_rewriter, ) -from pytensor.raise_op import Assert from pytensor.scalar.basic import Mul from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal from pytensor.tensor.blas import Dot22 @@ -44,9 +43,6 @@ logger = logging.getLogger(__name__) -assert_square_matrix_for_det = Assert( - "Last 2 dimensions of the input tensor to det are not equal and thus, it is not square!" -) def is_matrix_transpose(x: TensorVariable) -> bool: @@ -434,14 +430,14 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) # Dealing with only one other input - if len(non_eye_inputs) >= 2: + if len(non_eye_inputs) != 1: return None # Rewrite is only applied if all the shapes are known - if non_eye_inputs[0].type.shape[-2:] == (None, None) or eye_input[0].type.shape[ - -2: - ] == (None, None): - return None + # if non_eye_inputs[0].type.shape[-2:] == (None, None) or eye_input[0].type.shape[ + # -2: + # ] == (None, None): + # return None # Checking if original x was scalar/vector/matrix if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): @@ -455,7 +451,6 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): else: # For vector det_val = non_eye_inputs[0].prod(axis=(-1, -2)) - return [det_val] From 2c96faf93be5581d980380023d1958fb20039fc3 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 24 Jun 2024 13:45:27 +0530 Subject: [PATCH 21/31] added new tests for (1,1) eye and rectangle eye --- pytensor/tensor/rewriting/linalg.py | 32 +++++++++++-------- tests/tensor/rewriting/test_linalg.py | 46 +++++++++++++++++---------- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index dc0465bf1d..418738352a 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -385,8 +385,8 @@ def local_lift_through_linalg( raise NotImplementedError # pragma: no cover -@register_canonicalize -@register_stabilize +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") @node_rewriter([det]) def rewrite_det_diag_from_eye_mul(fgraph, node): """ @@ -440,18 +440,24 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): # return None # Checking if original x was scalar/vector/matrix - if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): - # For scalar - det_val = non_eye_inputs[0].squeeze(axis=(-1, -2)) ** ( - eye_input[0].shape[0] - ).astype(config.floatX) - elif non_eye_inputs[0].type.broadcastable[-2:] == (False, False): - # For Matrix - det_val = non_eye_inputs[0].diagonal(axis1=-1, axis2=-2).prod(axis=-1) + if ( + eye_input[0].type.shape[-1] is not None + and eye_input[0].type.shape[-2] is not None + ) and (eye_input[0].type.shape[-1] == eye_input[0].type.shape[-2]): + if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): + # For scalar + det_val = non_eye_inputs[0].squeeze(axis=(-1, -2)) ** ( + eye_input[0].shape[0] + ).astype(config.floatX) + elif non_eye_inputs[0].type.broadcastable[-2:] == (False, False): + # For Matrix + det_val = non_eye_inputs[0].diagonal(axis1=-1, axis2=-2).prod(axis=-1) + else: + # For vector + det_val = non_eye_inputs[0].prod(axis=(-1, -2)) + return [det_val] else: - # For vector - det_val = non_eye_inputs[0].prod(axis=(-1, -2)) - return [det_val] + return None arange = ARange("int64") diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index fc9e17304d..450614dbde 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -452,27 +452,39 @@ def test_det_diag_from_diag(): ) -@pytest.mark.parametrize( - "shape", [(1, 1), (7, 5)], ids=["degenerate_eye", "rectangle_eye"] -) -def test_dont_apply_det_diag_rewrite(shape): +def test_dont_apply_det_diag_rewrite_for_1_1(): x = pt.matrix("x") - x_diag = pt.eye(*shape) * x + x_diag = pt.eye(1, 1) * x y = pt.linalg.det(x_diag) f_rewritten = function([x], y, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Det) for node in nodes) - # NUMERIC VALUE TEST (only in case of (1,1)) - if shape != (7, 5): - x_test = np.random.normal(size=(3, 3)).astype(config.floatX) - x_test_matrix = np.eye(*shape) * x_test - det_val = np.linalg.det(x_test_matrix) - rewritten_val = f_rewritten(x_test) - assert_allclose( - det_val, - rewritten_val, - atol=1e-3 if config.floatX == "float32" else 1e-8, - rtol=1e-3 if config.floatX == "float32" else 1e-8, - ) + # Numeric Value test + x_test = np.random.normal(size=(3, 3)).astype(config.floatX) + x_test_matrix = np.eye(1, 1) * x_test + det_val = np.linalg.det(x_test_matrix) + rewritten_val = f_rewritten(x_test) + assert_allclose( + det_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_det_diag_incorrect_for_rectangle_eye(): + x = pt.matrix("x") + x_diag = pt.eye(7, 5) * x + y = pt.linalg.det(x_diag) + f_rewritten = function([x], y, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + assert any(isinstance(node.op, Det) for node in nodes) + + # Numeric Value test (should fail) + x_test = np.random.normal(size=(7, 5)).astype(config.floatX) + x_test_matrix = np.eye(7, 5) * x_test + with pytest.raises(np.linalg.LinAlgError, match="Last 2 dimensions"): + np.linalg.det(x_test_matrix) From 9a48a727c81200b47b2015c66565f269801f5852 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Tue, 25 Jun 2024 20:23:53 +0530 Subject: [PATCH 22/31] added helper function for diag from eye_mul --- pytensor/tensor/rewriting/linalg.py | 49 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 418738352a..7fb82c8ab5 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -385,27 +385,7 @@ def local_lift_through_linalg( raise NotImplementedError # pragma: no cover -@register_canonicalize("shape_unsafe") -@register_stabilize("shape_unsafe") -@node_rewriter([det]) -def rewrite_det_diag_from_eye_mul(fgraph, node): - """ - This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. - - The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix. - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ +def _find_diag_from_eye_mul(node): node_input = node.inputs[0] # Check if the op is Elemwise and mul if not ( @@ -428,6 +408,33 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): # Get all non Eye inputs (scalars/matrices/vectors) non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) + return eye_input, non_eye_inputs + + +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") +@node_rewriter([det]) +def rewrite_det_diag_from_eye_mul(fgraph, node): + """ + This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. + + The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + eye_non_eye_inputs = _find_diag_from_eye_mul(node) + if eye_non_eye_inputs is not None: + eye_input, non_eye_inputs = eye_non_eye_inputs # Dealing with only one other input if len(non_eye_inputs) != 1: From ffc0f819109e3cf4a8aff1ef96460cdd03e0cbad Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Tue, 25 Jun 2024 20:56:08 +0530 Subject: [PATCH 23/31] updated case for no rewrite which was failing tests --- pytensor/tensor/rewriting/linalg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 7fb82c8ab5..d5fdce67bd 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -435,6 +435,8 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): eye_non_eye_inputs = _find_diag_from_eye_mul(node) if eye_non_eye_inputs is not None: eye_input, non_eye_inputs = eye_non_eye_inputs + else: + return None # Dealing with only one other input if len(non_eye_inputs) != 1: From 9f661e3bffc231b75bad2fac9620454ef4fe86ff Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 26 Jun 2024 16:08:01 +0530 Subject: [PATCH 24/31] cleaned code; updated rectangle_eye test which is an invalid rewrite --- pytensor/tensor/rewriting/linalg.py | 53 +++++++++++---------------- tests/tensor/rewriting/test_linalg.py | 3 +- 2 files changed, 23 insertions(+), 33 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index d5fdce67bd..83186e050d 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -385,18 +385,17 @@ def local_lift_through_linalg( raise NotImplementedError # pragma: no cover -def _find_diag_from_eye_mul(node): - node_input = node.inputs[0] +def _find_diag_from_eye_mul(potential_mul_input): # Check if the op is Elemwise and mul if not ( - node_input.owner is not None - and isinstance(node_input.owner.op, Elemwise) - and isinstance(node_input.owner.op.scalar_op, Mul) + potential_mul_input.owner is not None + and isinstance(potential_mul_input.owner.op, Elemwise) + and isinstance(potential_mul_input.owner.op.scalar_op, Mul) ): return None # Find whether any of the inputs to mul is Eye - inputs_to_mul = node_input.owner.inputs + inputs_to_mul = potential_mul_input.owner.inputs eye_input = [ mul_input for mul_input in inputs_to_mul @@ -432,41 +431,31 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - eye_non_eye_inputs = _find_diag_from_eye_mul(node) - if eye_non_eye_inputs is not None: - eye_input, non_eye_inputs = eye_non_eye_inputs - else: + potential_mul_input = node.inputs[0] + eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input) + if eye_non_eye_inputs is None: return None + eye_input, non_eye_inputs = eye_non_eye_inputs # Dealing with only one other input if len(non_eye_inputs) != 1: return None - # Rewrite is only applied if all the shapes are known - # if non_eye_inputs[0].type.shape[-2:] == (None, None) or eye_input[0].type.shape[ - # -2: - # ] == (None, None): - # return None + useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] # Checking if original x was scalar/vector/matrix - if ( - eye_input[0].type.shape[-1] is not None - and eye_input[0].type.shape[-2] is not None - ) and (eye_input[0].type.shape[-1] == eye_input[0].type.shape[-2]): - if non_eye_inputs[0].type.broadcastable[-2:] == (True, True): - # For scalar - det_val = non_eye_inputs[0].squeeze(axis=(-1, -2)) ** ( - eye_input[0].shape[0] - ).astype(config.floatX) - elif non_eye_inputs[0].type.broadcastable[-2:] == (False, False): - # For Matrix - det_val = non_eye_inputs[0].diagonal(axis1=-1, axis2=-2).prod(axis=-1) - else: - # For vector - det_val = non_eye_inputs[0].prod(axis=(-1, -2)) - return [det_val] + if useful_non_eye.type.broadcastable[-2:] == (True, True): + # For scalar + det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0]).astype( + config.floatX + ) + elif useful_non_eye.type.broadcastable[-2:] == (False, False): + # For Matrix + det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1) else: - return None + # For vector + det_val = useful_non_eye.prod(axis=(-1, -2)) + return [det_val] arange = ARange("int64") diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 450614dbde..9ee6d5e7c4 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -481,7 +481,8 @@ def test_det_diag_incorrect_for_rectangle_eye(): f_rewritten = function([x], y, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes - assert any(isinstance(node.op, Det) for node in nodes) + assert not any(isinstance(node.op, Det) for node in nodes) + # This assert passes which means that the rewrite is applied even if the input is not square # Numeric Value test (should fail) x_test = np.random.normal(size=(7, 5)).astype(config.floatX) From b02818db29f45f2c147307f0ec768554eb530f7a Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 26 Jun 2024 19:34:50 +0530 Subject: [PATCH 25/31] add check for k in pt.eye --- pytensor/tensor/rewriting/linalg.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 83186e050d..d3faf2cc29 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -401,6 +401,11 @@ def _find_diag_from_eye_mul(potential_mul_input): for mul_input in inputs_to_mul if mul_input.owner and isinstance(mul_input.owner.op, Eye) ] + + # Check if 1's are being put on the main diagonal only (k = 1) + if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", 0).item() != 0: + return None + # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite if eye_input and eye_input[0].broadcastable[-2:] != (False, False): return None From 0dbc28e6dbfbf229f87756053de5fc3157befc93 Mon Sep 17 00:00:00 2001 From: Tanish Date: Wed, 26 Jun 2024 19:47:58 +0530 Subject: [PATCH 26/31] Update pytensor/tensor/rewriting/linalg.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index d3faf2cc29..17de8832e8 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -403,7 +403,7 @@ def _find_diag_from_eye_mul(potential_mul_input): ] # Check if 1's are being put on the main diagonal only (k = 1) - if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", 0).item() != 0: + if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0: return None # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite From d4ffb7f0cdbf803f10bb18ee29d73d2cd996cf2f Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 26 Jun 2024 23:37:16 +0530 Subject: [PATCH 27/31] typecasted det_val --- pytensor/tensor/rewriting/linalg.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index d3faf2cc29..66691dbbd3 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,7 +2,7 @@ from collections.abc import Callable from typing import cast -from pytensor import Variable, config +from pytensor import Variable from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import ( PatternNodeRewriter, @@ -451,15 +451,14 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): # Checking if original x was scalar/vector/matrix if useful_non_eye.type.broadcastable[-2:] == (True, True): # For scalar - det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0]).astype( - config.floatX - ) + det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0]) elif useful_non_eye.type.broadcastable[-2:] == (False, False): # For Matrix det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1) else: # For vector det_val = useful_non_eye.prod(axis=(-1, -2)) + det_val = det_val.astype(useful_non_eye.dtype) return [det_val] From 557fea1057326f0a8d683118677ceaf899b3c84b Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Thu, 27 Jun 2024 00:58:44 +0530 Subject: [PATCH 28/31] fixed final typecasting --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index f28c31873d..f3a1d4898b 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -458,7 +458,7 @@ def rewrite_det_diag_from_eye_mul(fgraph, node): else: # For vector det_val = useful_non_eye.prod(axis=(-1, -2)) - det_val = det_val.astype(useful_non_eye.dtype) + det_val = det_val.astype(node.outputs[0].type.dtype) return [det_val] From 27a98641375ed7ab6dd91ca64850b47aff7773e7 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 1 Jul 2024 15:27:46 +0530 Subject: [PATCH 29/31] fixed merge --- tests/tensor/rewriting/test_linalg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index ffc33a3042..c3fb8d5479 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -492,6 +492,7 @@ def test_det_diag_incorrect_for_rectangle_eye(): with pytest.raises(np.linalg.LinAlgError, match="Last 2 dimensions"): np.linalg.det(x_test_matrix) + def test_svd_uv_merge(): a = matrix("a") s_1 = svd(a, full_matrices=False, compute_uv=False) From 683106901739e3032fd3d47e5bf835c42047aeee Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Mon, 1 Jul 2024 16:18:39 +0530 Subject: [PATCH 30/31] fixed failing rectangle eye test --- tests/tensor/rewriting/test_linalg.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c3fb8d5479..d59e3cc88f 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -479,18 +479,8 @@ def test_dont_apply_det_diag_rewrite_for_1_1(): def test_det_diag_incorrect_for_rectangle_eye(): x = pt.matrix("x") x_diag = pt.eye(7, 5) * x - y = pt.linalg.det(x_diag) - f_rewritten = function([x], y, mode="FAST_RUN") - nodes = f_rewritten.maker.fgraph.apply_nodes - - assert not any(isinstance(node.op, Det) for node in nodes) - # This assert passes which means that the rewrite is applied even if the input is not square - - # Numeric Value test (should fail) - x_test = np.random.normal(size=(7, 5)).astype(config.floatX) - x_test_matrix = np.eye(7, 5) * x_test - with pytest.raises(np.linalg.LinAlgError, match="Last 2 dimensions"): - np.linalg.det(x_test_matrix) + with pytest.raises(ValueError, match="Determinant not defined"): + pt.linalg.det(x_diag) def test_svd_uv_merge(): From 9811b880b934502f70641975a1808d45c491dab9 Mon Sep 17 00:00:00 2001 From: Tanish Taneja Date: Wed, 3 Jul 2024 16:10:42 +0530 Subject: [PATCH 31/31] fixed typo --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 4b5014b41d..3c98834c94 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -404,7 +404,7 @@ def _find_diag_from_eye_mul(potential_mul_input): if mul_input.owner and isinstance(mul_input.owner.op, Eye) ] - # Check if 1's are being put on the main diagonal only (k = 1) + # Check if 1's are being put on the main diagonal only (k = 0) if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0: return None