From 6eefbbe2d6b141ad94b8b4ed8010f59dd908a7c6 Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 10 Aug 2024 16:00:28 +0530 Subject: [PATCH 1/4] added rewrite for diag(block_diag) --- pytensor/tensor/rewriting/linalg.py | 24 ++++++++++++++++++++++++ tests/tensor/rewriting/test_linalg.py | 27 +++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 47ca08cf21..e5e0302474 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -12,8 +12,11 @@ from pytensor.scalar.basic import Mul from pytensor.tensor.basic import ( AllocDiag, + ExtractDiag, Eye, TensorVariable, + concatenate, + diag, diagonal, ) from pytensor.tensor.blas import Dot22 @@ -701,3 +704,24 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): non_eye_input = pt.shape_padaxis(non_eye_diag, -2) return [eye_input / non_eye_input] + + +@register_canonicalize +@register_stabilize +@node_rewriter([ExtractDiag]) +def rewrite_diag_blockdiag(fgraph, node): + # Check for inner block_diag operation + potential_blockdiag = node.inputs[0].owner + if not ( + potential_blockdiag + and isinstance(potential_blockdiag.op, Blockwise) + and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) + ): + return None + + # Find the composing sub_matrices + submatrices = potential_blockdiag.inputs + submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))] + output = [concatenate(submatrices_diag)] + + return output diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 0bee56eb30..45cbe9c969 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -662,3 +662,30 @@ def test_inv_diag_from_diag(inv_op): atol=ATOL, rtol=RTOL, ) + + +def test_diag_blockdiag_rewrite(): + n_matrices = 100 + matrix_size = (5, 5) + sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) + bd_output 
= pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) + diag_output = pt.diag(bd_output) + f_rewritten = function([sub_matrices], diag_output, mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, BlockDiagonal) for node in nodes) + + # Value Test + sub_matrices_test = np.random.rand(n_matrices, *matrix_size) + bd_output_test = scipy.linalg.block_diag( + *[sub_matrices_test[i] for i in range(n_matrices)] + ) + diag_output_test = np.diag(bd_output_test) + rewritten_val = f_rewritten(sub_matrices_test) + assert_allclose( + diag_output_test, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From 28d397b31f482da37f0b0a3e33cff5e37deecc68 Mon Sep 17 00:00:00 2001 From: Tanish Date: Fri, 16 Aug 2024 23:46:50 +0530 Subject: [PATCH 2/4] added rewrite for determinant of blockdiag --- pytensor/tensor/rewriting/linalg.py | 21 +++++++++++++++++++++ tests/tensor/rewriting/test_linalg.py | 27 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index e5e0302474..701a0f840e 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -725,3 +725,24 @@ def rewrite_diag_blockdiag(fgraph, node): output = [concatenate(submatrices_diag)] return output + + +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def rewrite_det_blockdiag(fgraph, node): + # Check for inner block_diag operation + potential_blockdiag = node.inputs[0].owner + if not ( + potential_blockdiag + and isinstance(potential_blockdiag.op, Blockwise) + and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) + ): + return None + + # Find the composing sub_matrices + sub_matrices = potential_blockdiag.inputs + det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))] + prod_det_sub_matrices = 
prod(det_sub_matrices) + + return [prod_det_sub_matrices] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 45cbe9c969..3d112c4551 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -689,3 +689,30 @@ def test_diag_blockdiag_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_det_blockdiag_rewrite(): + n_matrices = 100 + matrix_size = (5, 5) + sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) + bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) + det_output = pt.linalg.det(bd_output) + f_rewritten = function([sub_matrices], det_output, mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, BlockDiagonal) for node in nodes) + + # Value Test + sub_matrices_test = np.random.rand(n_matrices, *matrix_size) + bd_output_test = scipy.linalg.block_diag( + *[sub_matrices_test[i] for i in range(n_matrices)] + ) + det_output_test = np.linalg.det(bd_output_test) + rewritten_val = f_rewritten(sub_matrices_test) + assert_allclose( + det_output_test, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From 48f6527acb746716e991cb1d3cec13435fa52006 Mon Sep 17 00:00:00 2001 From: Tanish Date: Sun, 18 Aug 2024 23:13:18 +0530 Subject: [PATCH 3/4] Added rewrite for slogdet; added docstrings for all 3 rewrites --- pytensor/tensor/rewriting/linalg.py | 100 ++++++++++++++++++++++---- tests/tensor/rewriting/test_linalg.py | 37 +++++++++- 2 files changed, 122 insertions(+), 15 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 701a0f840e..89ce9e8c73 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -32,6 +32,7 @@ inv, kron, pinv, + 
slogdet, svd, ) from pytensor.tensor.rewriting.basic import ( @@ -710,39 +711,110 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): @register_stabilize @node_rewriter([ExtractDiag]) def rewrite_diag_blockdiag(fgraph, node): + """ + This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatenating the diagonal values of all of the individual sub matrices. + + diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...) + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ # Check for inner block_diag operation - potential_blockdiag = node.inputs[0].owner + potential_block_diag = node.inputs[0].owner if not ( - potential_blockdiag - and isinstance(potential_blockdiag.op, Blockwise) - and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) + potential_block_diag + and isinstance(potential_block_diag.op, Blockwise) + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) ): return None # Find the composing sub_matrices - submatrices = potential_blockdiag.inputs + submatrices = potential_block_diag.inputs submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))] - output = [concatenate(submatrices_diag)] - return output + return [concatenate(submatrices_diag)] @register_canonicalize @register_stabilize @node_rewriter([det]) def rewrite_det_blockdiag(fgraph, node): + """ + This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values. + + det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...)
+ + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ # Check for inner block_diag operation - potential_blockdiag = node.inputs[0].owner + potential_block_diag = node.inputs[0].owner if not ( - potential_blockdiag - and isinstance(potential_blockdiag.op, Blockwise) - and isinstance(potential_blockdiag.op.core_op, BlockDiagonal) + potential_block_diag + and isinstance(potential_block_diag.op, Blockwise) + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) ): return None # Find the composing sub_matrices - sub_matrices = potential_blockdiag.inputs + sub_matrices = potential_block_diag.inputs det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))] - prod_det_sub_matrices = prod(det_sub_matrices) - return [prod_det_sub_matrices] + return [prod(det_sub_matrices)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([slogdet]) +def rewrite_slogdet_blockdiag(fgraph, node): + """ + This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + + slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) 
+ + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner block_diag operation + potential_block_diag = node.inputs[0].owner + if not ( + potential_block_diag + and isinstance(potential_block_diag.op, Blockwise) + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) + ): + return None + + # Find the composing sub_matrices + sub_matrices = potential_block_diag.inputs + sign_sub_matrices, logdet_sub_matrices = zip( + *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] + ) + + return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 3d112c4551..db17a802a8 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -665,7 +665,7 @@ def test_inv_diag_from_diag(inv_op): def test_diag_blockdiag_rewrite(): - n_matrices = 100 + n_matrices = 10 matrix_size = (5, 5) sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) @@ -716,3 +716,38 @@ def test_det_blockdiag_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_slogdet_blockdiag_rewrite(): + n_matrices = 100 + matrix_size = (5, 5) + sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) + bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) + sign_output, logdet_output = pt.linalg.slogdet(bd_output) + f_rewritten = function( + [sub_matrices], [sign_output, logdet_output], mode="FAST_RUN" + ) + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, BlockDiagonal) for node in 
nodes) + + # Value Test + sub_matrices_test = np.random.rand(n_matrices, *matrix_size) + bd_output_test = scipy.linalg.block_diag( + *[sub_matrices_test[i] for i in range(n_matrices)] + ) + sign_output_test, logdet_output_test = np.linalg.slogdet(bd_output_test) + rewritten_sign_val, rewritten_logdet_val = f_rewritten(sub_matrices_test) + assert_allclose( + sign_output_test, + rewritten_sign_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + logdet_output_test, + rewritten_logdet_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From fa7d9daa2a0e0568a4b84086fca3852c3aa82add Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 19 Aug 2024 14:12:36 +0530 Subject: [PATCH 4/4] fixed typecasting for tests --- tests/tensor/rewriting/test_linalg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index db17a802a8..133e8d6a31 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -677,7 +677,7 @@ def test_diag_blockdiag_rewrite(): assert not any(isinstance(node.op, BlockDiagonal) for node in nodes) # Value Test - sub_matrices_test = np.random.rand(n_matrices, *matrix_size) + sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX) bd_output_test = scipy.linalg.block_diag( *[sub_matrices_test[i] for i in range(n_matrices)] ) @@ -704,7 +704,7 @@ def test_det_blockdiag_rewrite(): assert not any(isinstance(node.op, BlockDiagonal) for node in nodes) # Value Test - sub_matrices_test = np.random.rand(n_matrices, *matrix_size) + sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX) bd_output_test = scipy.linalg.block_diag( *[sub_matrices_test[i] for i in range(n_matrices)] ) @@ -733,7 +733,7 @@ def test_slogdet_blockdiag_rewrite(): assert not 
any(isinstance(node.op, BlockDiagonal) for node in nodes) # Value Test - sub_matrices_test = np.random.rand(n_matrices, *matrix_size) + sub_matrices_test = np.random.rand(n_matrices, *matrix_size).astype(config.floatX) bd_output_test = scipy.linalg.block_diag( *[sub_matrices_test[i] for i in range(n_matrices)] )