From bd1dbb156cd89f736607049b0ab89f6728f3b6ce Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 19 Oct 2024 21:43:26 +0530 Subject: [PATCH 01/16] Added alternative slogdet to return sign and logdet of det op --- pytensor/tensor/nlinalg.py | 5 +- pytensor/tensor/rewriting/linalg.py | 141 ++++++++++++++-------------- 2 files changed, 74 insertions(+), 72 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index e7093a82bd..a1015f51e8 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -266,7 +266,10 @@ def __str__(self): return "SLogDet" -slogdet = Blockwise(SLogDet()) +# slogdet = Blockwise(SLogDet()) +def slogdet(x): + det_val = det(x) + return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) class Eig(Op): diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a2418147cf..40b4ca02a0 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -34,7 +34,6 @@ inv, kron, pinv, - slogdet, svd, ) from pytensor.tensor.rewriting.basic import ( @@ -785,43 +784,43 @@ def rewrite_det_blockdiag(fgraph, node): return [prod(det_sub_matrices)] -@register_canonicalize -@register_stabilize -@node_rewriter([slogdet]) -def rewrite_slogdet_blockdiag(fgraph, node): - """ - This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those - - slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - sub_matrices = potential_block_diag.inputs - sign_sub_matrices, logdet_sub_matrices = zip( - *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] - ) - - return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] +# @register_canonicalize +# @register_stabilize +# @node_rewriter([slogdet]) +# def rewrite_slogdet_blockdiag(fgraph, node): +# """ +# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + +# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) 
+ +# Parameters +# ---------- +# fgraph: FunctionGraph +# Function graph being optimized +# node: Apply +# Node of the function graph to be optimized + +# Returns +# ------- +# list of Variable, optional +# List of optimized variables, or None if no optimization was performed +# """ +# # Check for inner block_diag operation +# potential_block_diag = node.inputs[0].owner +# if not ( +# potential_block_diag +# and isinstance(potential_block_diag.op, Blockwise) +# and isinstance(potential_block_diag.op.core_op, BlockDiagonal) +# ): +# return None + +# # Find the composing sub_matrices +# sub_matrices = potential_block_diag.inputs +# sign_sub_matrices, logdet_sub_matrices = zip( +# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] +# ) + +# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] @register_canonicalize @@ -858,39 +857,39 @@ def rewrite_diag_kronecker(fgraph, node): return [outer_prod_as_vector] -@register_canonicalize -@register_stabilize -@node_rewriter([slogdet]) -def rewrite_slogdet_kronecker(fgraph, node): - """ - This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner kron operation - potential_kron = node.inputs[0].owner - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): - return None - - # Find the matrices - a, b = potential_kron.inputs - signs, logdets = zip(*[slogdet(a), slogdet(b)]) - sizes = [a.shape[-1], b.shape[-1]] - prod_sizes = prod(sizes, no_zeros_in_input=True) - signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] - logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] - - return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] +# @register_canonicalize +# @register_stabilize +# @node_rewriter([slogdet]) +# def rewrite_slogdet_kronecker(fgraph, node): +# """ +# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + +# Parameters +# ---------- +# fgraph: FunctionGraph +# Function graph being optimized +# node: Apply +# Node of the function graph to be optimized + +# Returns +# ------- +# list of Variable, optional +# List of optimized variables, or None if no optimization was performed +# """ +# # Check for inner kron operation +# potential_kron = node.inputs[0].owner +# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): +# return None + +# # Find the matrices +# a, b = potential_kron.inputs +# signs, logdets = zip(*[slogdet(a), slogdet(b)]) +# sizes = [a.shape[-1], b.shape[-1]] +# prod_sizes = prod(sizes, no_zeros_in_input=True) +# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] +# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] + +# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] @register_canonicalize From 36ed08493efb0f066cf631b2f922d68f8b9f5608 Mon Sep 17 00:00:00 2001 From: Tanish Date: Sun, 20 Oct 2024 02:45:41 +0530 Subject: [PATCH 02/16] removed rewrites for slogdet and added the same for det which will be used later --- 
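Quick NumPy sketch (illustration only, not applied by the patch) of the identity the new rewrite_det_kronecker below relies on: for A of shape (n, n) and B of shape (m, m), det(kron(A, B)) = det(A)**m * det(B)**n. The sizes used here are arbitrary.

    import numpy as np

    rng = np.random.default_rng(0)
    n, m = 3, 4
    A = rng.standard_normal((n, n))
    B = rng.standard_normal((m, m))

    # det of the Kronecker product vs. the factored form used by the rewrite
    assert np.isclose(
        np.linalg.det(np.kron(A, B)),
        np.linalg.det(A) ** m * np.linalg.det(B) ** n,
    )
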
pytensor/tensor/rewriting/linalg.py | 57 +++++++++++++-------------- tests/tensor/rewriting/test_linalg.py | 20 ++++------ 2 files changed, 35 insertions(+), 42 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 40b4ca02a0..396745fed1 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -857,39 +857,38 @@ def rewrite_diag_kronecker(fgraph, node): return [outer_prod_as_vector] -# @register_canonicalize -# @register_stabilize -# @node_rewriter([slogdet]) -# def rewrite_slogdet_kronecker(fgraph, node): -# """ -# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def rewrite_det_kronecker(fgraph, node): + """ + This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those -# Parameters -# ---------- -# fgraph: FunctionGraph -# Function graph being optimized -# node: Apply -# Node of the function graph to be optimized + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized -# Returns -# ------- -# list of Variable, optional -# List of optimized variables, or None if no optimization was performed -# """ -# # Check for inner kron operation -# potential_kron = node.inputs[0].owner -# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): -# return None + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None -# # Find the matrices -# a, b = potential_kron.inputs -# signs, logdets = zip(*[slogdet(a), slogdet(b)]) -# sizes = [a.shape[-1], b.shape[-1]] -# prod_sizes = prod(sizes, no_zeros_in_input=True) -# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] -# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] + # Find the matrices + a, b = potential_kron.inputs + dets = [det(a), det(b)] + sizes = [a.shape[-1], b.shape[-1]] + prod_sizes = prod(sizes, no_zeros_in_input=True) + det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)]) -# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] + return [det_final] @register_canonicalize diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9dd2a247a8..9ff0ff76d9 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -776,11 +776,11 @@ def test_diag_kronecker_rewrite(): ) -def test_slogdet_kronecker_rewrite(): +def test_det_kronecker_rewrite(): a, b = pt.dmatrices("a", "b") kron_prod = pt.linalg.kron(a, b) - sign_output, logdet_output = pt.linalg.slogdet(kron_prod) - f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN") + det_output = pt.linalg.det(kron_prod) + f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN") # Rewrite Test nodes = f_rewritten.maker.fgraph.apply_nodes @@ -789,17 +789,11 @@ def test_slogdet_kronecker_rewrite(): # Value Test a_test, b_test = np.random.rand(2, 20, 20) kron_prod_test = np.kron(a_test, b_test) - 
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test) - rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test) + det_output_test = np.linalg.det(kron_prod_test) + rewritten_det_val = f_rewritten(kron_prod_test) assert_allclose( - sign_output_test, - rewritten_sign_val, - atol=1e-3 if config.floatX == "float32" else 1e-8, - rtol=1e-3 if config.floatX == "float32" else 1e-8, - ) - assert_allclose( - logdet_output_test, - rewritten_logdet_val, + det_output_test, + rewritten_det_val, atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) From 7dee9773bb36fe7798bef38562ad50e0463d4727 Mon Sep 17 00:00:00 2001 From: Tanish Date: Wed, 30 Oct 2024 20:29:45 +0530 Subject: [PATCH 03/16] added specialised rewrite for slogdet --- pytensor/tensor/rewriting/linalg.py | 1022 +++++++++++++++++++++++++++ 1 file changed, 1022 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 396745fed1..f2eb54db84 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,6 +2,1028 @@ from collections.abc import Callable from typing import cast +from pytensor import Variable +from pytensor import tensor as pt +from pytensor.graph import Apply, FunctionGraph +from pytensor.graph.rewriting.basic import ( + copy_stack_trace, + node_rewriter, +) +from pytensor.scalar.basic import Abs, Log, Mul, Sign +from pytensor.tensor.basic import ( + AllocDiag, + ExtractDiag, + Eye, + TensorVariable, + concatenate, + diag, + diagonal, +) +from pytensor.tensor.blas import Dot22 +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod +from pytensor.tensor.nlinalg import ( + SVD, + KroneckerProduct, + MatrixInverse, + MatrixPinv, + SLogDet, + det, + inv, + kron, + pinv, + svd, +) +from pytensor.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) +from pytensor.tensor.slinalg import ( + BlockDiagonal, + Cholesky, + Solve, + SolveBase, + block_diag, + cholesky, + solve, + solve_triangular, +) + + +logger = logging.getLogger(__name__) +ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv) + + +def is_matrix_transpose(x: TensorVariable) -> bool: + """Check if a variable corresponds to a transpose of the last two axes""" + node = x.owner + if ( + node + and isinstance(node.op, DimShuffle) + and not (node.op.drop or node.op.augment) + ): + [inp] = node.inputs + ndims = inp.type.ndim + if ndims < 2: + return False + transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2) + return node.op.new_order == transpose_order + return False + + +@register_canonicalize +@node_rewriter([DimShuffle]) +def transinv_to_invtrans(fgraph, node): + if is_matrix_transpose(node.outputs[0]): + (A,) = node.inputs + if ( + A.owner + and isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, MatrixInverse) + ): + (X,) = A.owner.inputs + return [A.owner.op(node.op(X))] + + +@register_stabilize +@node_rewriter([Dot, Dot22]) +def inv_as_solve(fgraph, node): + """ + This utilizes a boolean `symmetric` tag on the matrices. 
+ """ + if isinstance(node.op, Dot | Dot22): + l, r = node.inputs + if ( + l.owner + and isinstance(l.owner.op, Blockwise) + and isinstance(l.owner.op.core_op, MatrixInverse) + ): + return [solve(l.owner.inputs[0], r)] + if ( + r.owner + and isinstance(r.owner.op, Blockwise) + and isinstance(r.owner.op.core_op, MatrixInverse) + ): + x = r.owner.inputs[0] + if getattr(x.tag, "symmetric", None) is True: + return [solve(x, (l.mT)).mT] + else: + return [solve((x.mT), (l.mT)).mT] + + +@register_stabilize +@register_canonicalize +@node_rewriter([Blockwise]) +def generic_solve_to_solve_triangular(fgraph, node): + """ + If any solve() is applied to the output of a cholesky op, then + replace it with a triangular solve. + + """ + if isinstance(node.op.core_op, Solve): + if node.op.core_op.assume_a == "gen": + A, b = node.inputs # result is solution Ax=b + if ( + A.owner + and isinstance(A.owner.op, Blockwise) + and isinstance(A.owner.op.core_op, Cholesky) + ): + if A.owner.op.core_op.lower: + return [ + solve_triangular( + A, b, lower=True, b_ndim=node.op.core_op.b_ndim + ) + ] + else: + return [ + solve_triangular( + A, b, lower=False, b_ndim=node.op.core_op.b_ndim + ) + ] + if is_matrix_transpose(A): + (A_T,) = A.owner.inputs + if ( + A_T.owner + and isinstance(A_T.owner.op, Blockwise) + and isinstance(A_T.owner.op, Cholesky) + ): + if A_T.owner.op.lower: + return [ + solve_triangular( + A, b, lower=False, b_ndim=node.op.core_op.b_ndim + ) + ] + else: + return [ + solve_triangular( + A, b, lower=True, b_ndim=node.op.core_op.b_ndim + ) + ] + + +@register_specialize +@node_rewriter([Blockwise]) +def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): + """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T + + `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions. + """ + core_op = node.op.core_op + + if not isinstance(core_op, SolveBase): + return None + + if node.op.core_op.b_ndim != 1: + return None + + [a, b] = node.inputs + + # Check `b` is actually batched + if b.type.ndim == 1: + return None + + # Check `a` is a matrix (possibly with degenerate dims on the left) + a_bcast_batch_dims = a.type.broadcastable[:-2] + if not all(a_bcast_batch_dims): + return None + # We squeeze degenerate dims, any that are still needed will be introduced by the new_solve + elif len(a_bcast_batch_dims): + a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims)))) + + # Recreate solve Op with b_ndim=2 + props = core_op._props_dict() + props["b_ndim"] = 2 + new_core_op = type(core_op)(**props) + matrix_b_solve = Blockwise(new_core_op) + + # Ravel any batched dims + original_b_shape = tuple(b.shape) + if len(original_b_shape) > 2: + b = b.reshape((-1, original_b_shape[-1])) + + # Apply the rewrite + new_solve = matrix_b_solve(a, b.T).T + + # Unravel any batched dims + if len(original_b_shape) > 2: + new_solve = new_solve.reshape(original_b_shape) + + old_solve = node.outputs[0] + copy_stack_trace(old_solve, new_solve) + + return [new_solve] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([DimShuffle]) +def no_transpose_symmetric(fgraph, node): + if is_matrix_transpose(node.outputs[0]): + x = node.inputs[0] + if getattr(x.tag, "symmetric", None): + return [x] + + +@register_stabilize +@node_rewriter([Blockwise]) +def psd_solve_with_chol(fgraph, node): + """ + This utilizes a boolean `psd` tag on matrices. 
+ """ + if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2: + A, b = node.inputs # result is solution Ax=b + if getattr(A.tag, "psd", None) is True: + L = cholesky(A) + # N.B. this can be further reduced to a yet-unwritten cho_solve Op + # __if__ no other Op makes use of the L matrix during the + # stabilization + Li_b = solve_triangular(L, b, lower=True, b_ndim=2) + x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2) + return [x] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def cholesky_ldotlt(fgraph, node): + """ + rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, + or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular. + + Also works with matmul. + + This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. + """ + if not isinstance(node.op.core_op, Cholesky): + return + + A = node.inputs[0] + if not ( + A.owner is not None + and ( + ( + isinstance(A.owner.op, Dot | Dot22) + # This rewrite only applies to matrix Dot + and A.owner.inputs[0].type.ndim == 2 + ) + or (A.owner.op == _matrix_matrix_matmul) + ) + ): + return + + l, r = A.owner.inputs + + # cholesky(dot(L,L.T)) case + if ( + getattr(l.tag, "lower_triangular", False) + and is_matrix_transpose(r) + and r.owner.inputs[0] == l + ): + if node.op.core_op.lower: + return [l] + return [r] + + # cholesky(dot(U.T,U)) case + if ( + getattr(r.tag, "upper_triangular", False) + and is_matrix_transpose(l) + and l.owner.inputs[0] == r + ): + if node.op.core_op.lower: + return [l] + return [r] + + +@register_stabilize +@register_specialize +@node_rewriter([det]) +def local_det_chol(fgraph, node): + """ + If we have det(X) and there is already an L=cholesky(X) + floating around, then we can use prod(diag(L)) to get the determinant. + + """ + (x,) = node.inputs + for cl, xpos in fgraph.clients[x]: + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): + L = cl.outputs[0] + return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_prod_sqr(fgraph, node): + """ + This utilizes a boolean `positive` tag on matrices. + """ + (x,) = node.inputs + if x.owner and isinstance(x.owner.op, Prod): + # we cannot always make this substitution because + # the prod might include negative terms + p = x.owner.inputs[0] + + # p is the matrix we're reducing with prod + if getattr(p.tag, "positive", None) is True: + return [log(p).sum(axis=x.owner.op.axis)] + + # TODO: have a reduction like prod and sum that simply + # returns the sign of the prod multiplication. + + +@register_specialize +@node_rewriter([Blockwise]) +def local_lift_through_linalg( + fgraph: FunctionGraph, node: Apply +) -> list[Variable] | None: + """ + Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops + that join matrices (KroneckerProduct, BlockDiagonal). + + This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix + operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker + product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This + reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component + matrices. 
+ + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + + # TODO: Simplify this if we end up Blockwising KroneckerProduct + if not isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv): + return None + + y = node.inputs[0] + outer_op = node.op + + if y.owner and ( + isinstance(y.owner.op, Blockwise) + and isinstance(y.owner.op.core_op, BlockDiagonal) + or isinstance(y.owner.op, KroneckerProduct) + ): + input_matrices = y.owner.inputs + + if isinstance(outer_op.core_op, MatrixInverse): + outer_f = cast(Callable, inv) + elif isinstance(outer_op.core_op, Cholesky): + outer_f = cast(Callable, cholesky) + elif isinstance(outer_op.core_op, MatrixPinv): + outer_f = cast(Callable, pinv) + else: + raise NotImplementedError # pragma: no cover + + inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices] + + if isinstance(y.owner.op, KroneckerProduct): + return [kron(*inner_matrices)] + elif isinstance(y.owner.op.core_op, BlockDiagonal): + return [block_diag(*inner_matrices)] + else: + raise NotImplementedError # pragma: no cover + return None + + +def _find_diag_from_eye_mul(potential_mul_input): + # Check if the op is Elemwise and mul + if not ( + potential_mul_input.owner is not None + and isinstance(potential_mul_input.owner.op, Elemwise) + and isinstance(potential_mul_input.owner.op.scalar_op, Mul) + ): + return None + + # Find whether any of the inputs to mul is Eye + inputs_to_mul = potential_mul_input.owner.inputs + eye_input = [ + mul_input + for mul_input in inputs_to_mul + if mul_input.owner + and ( + isinstance(mul_input.owner.op, Eye) + or + # This whole condition checks if there is an Eye hiding inside a DimShuffle. + # This arises from batched elementwise multiplication between a tensor and an eye, e.g.: + # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites. + ( + isinstance(mul_input.owner.op, DimShuffle) + and ( + mul_input.owner.op.is_left_expand_dims + or mul_input.owner.op.is_right_expand_dims + ) + and mul_input.owner.inputs[0].owner is not None + and isinstance(mul_input.owner.inputs[0].owner.op, Eye) + ) + ) + ] + + if not eye_input: + return None + + eye_input = eye_input[0] + # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset + if isinstance(eye_input.owner.op, Eye) and ( + not Eye.is_offset_zero(eye_input.owner) + or eye_input.broadcastable[-2:] != (False, False) + ): + return None + + # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on). 
+ # We have to look inside DimShuffle to decide if the rewrite can be applied + if isinstance(eye_input.owner.op, DimShuffle) and ( + eye_input.owner.op.is_left_expand_dims + or eye_input.owner.op.is_right_expand_dims + ): + inner_eye = eye_input.owner.inputs[0] + # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't + # degenerate + if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != ( + False, + False, + ): + return None + + # Get all non Eye inputs (scalars/matrices/vectors) + non_eye_inputs = list(set(inputs_to_mul) - {eye_input}) + return eye_input, non_eye_inputs + + +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") +@node_rewriter([det]) +def rewrite_det_diag_to_prod_diag(fgraph, node): + """ + This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its + diagonal elements. + + The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices + that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to + make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, + vector or a matrix. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + inputs = node.inputs[0] + + # Check for use of pt.diag first + if ( + inputs.owner + and isinstance(inputs.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(inputs.owner) + ): + diag_input = inputs.owner.inputs[0] + det_val = diag_input.prod(axis=-1) + return [det_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] + + # Checking if original x was scalar/vector/matrix + if non_eye_input.type.broadcastable[-2:] == (True, True): + # For scalar + det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0]) + elif non_eye_input.type.broadcastable[-2:] == (False, False): + # For Matrix + det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1) + else: + # For vector + det_val = non_eye_input.prod(axis=(-1, -2)) + det_val = det_val.astype(node.outputs[0].type.dtype) + return [det_val] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Blockwise]) +def svd_uv_merge(fgraph, node): + """If we have more than one `SVD` `Op`s and at least one has keyword argument + `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere + and allow `pytensor` to re-use the decomposition outputs instead of recomputing. + """ + if not isinstance(node.op.core_op, SVD): + return + + (x,) = node.inputs + + if node.op.core_op.compute_uv: + # compute_uv=True returns [u, s, v]. + # if at least u or v is used, no need to rewrite this node. + if ( + len(fgraph.clients[node.outputs[0]]) > 0 + or len(fgraph.clients[node.outputs[2]]) > 0 + ): + return + + # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. 
+ # First, iterate to see if there is an SVD Op that can be reused. + for cl, _ in fgraph.clients[x]: + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if not cl.op.core_op.compute_uv: + return { + node.outputs[1]: cl.outputs[0], + } + + # If no SVD reusable, return a new one. + return { + node.outputs[1]: svd( + x, full_matrices=node.op.core_op.full_matrices, compute_uv=False + ), + } + + else: + # compute_uv=False returns [s]. + # We want rewrite if there is another one with compute_uv=True. + # For this case, just reuse the `s` from the one with compute_uv=True. + for cl, _ in fgraph.clients[x]: + if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): + if cl.op.core_op.compute_uv and ( + len(fgraph.clients[cl.outputs[0]]) > 0 + or len(fgraph.clients[cl.outputs[2]]) > 0 + ): + return [cl.outputs[1]] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_inv(fgraph, node): + """ + This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once. + + Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check if its a valid inverse operation (either inv/pinv) + # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation + # If the outer operation is not a valid inverse, we do not apply this rewrite + if not isinstance(node.op.core_op, ALL_INVERSE_OPS): + return None + + potential_inner_inv = node.inputs[0].owner + if potential_inner_inv is None or potential_inner_inv.op is None: + return None + + # Check if inner op is blockwise and and possible inv + if not ( + potential_inner_inv + and isinstance(potential_inner_inv.op, Blockwise) + and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS) + ): + return None + return [potential_inner_inv.inputs[0]] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_eye_to_eye(fgraph, node): + """ + This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself + The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op. 
+ Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + core_op = node.op.core_op + if not (isinstance(core_op, ALL_INVERSE_OPS)): + return None + + # Check whether input to inverse is Eye and the 1's are on main diagonal + potential_eye = node.inputs[0] + if not ( + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0 + ): + return None + return [potential_eye] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): + """ + This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. + This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + core_op = node.op.core_op + if not (isinstance(core_op, ALL_INVERSE_OPS)): + return None + + inputs = node.inputs[0] + # Check for use of pt.diag first + if ( + inputs.owner + and isinstance(inputs.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(inputs.owner) + ): + inv_input = inputs.owner.inputs[0] + inv_val = pt.diag(1 / inv_input) + return [inv_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + non_eye_input = non_eye_inputs[0] + + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_input.type.broadcastable[-2:] == (False, False): + non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2) + non_eye_input = pt.shape_padaxis(non_eye_diag, -2) + + return [eye_input / non_eye_input] + + +@register_canonicalize +@register_stabilize +@node_rewriter([ExtractDiag]) +def rewrite_diag_blockdiag(fgraph, node): + """ + This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices. + + diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...) 
+ + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner block_diag operation + potential_block_diag = node.inputs[0].owner + if not ( + potential_block_diag + and isinstance(potential_block_diag.op, Blockwise) + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) + ): + return None + + # Find the composing sub_matrices + submatrices = potential_block_diag.inputs + submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))] + + return [concatenate(submatrices_diag)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def rewrite_det_blockdiag(fgraph, node): + """ + This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values. + + det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...) + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner block_diag operation + potential_block_diag = node.inputs[0].owner + if not ( + potential_block_diag + and isinstance(potential_block_diag.op, Blockwise) + and isinstance(potential_block_diag.op.core_op, BlockDiagonal) + ): + return None + + # Find the composing sub_matrices + sub_matrices = potential_block_diag.inputs + det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))] + + return [prod(det_sub_matrices)] + + +# @register_canonicalize +# @register_stabilize +# @node_rewriter([slogdet]) +# def rewrite_slogdet_blockdiag(fgraph, node): +# """ +# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + +# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) + +# Parameters +# ---------- +# fgraph: FunctionGraph +# Function graph being optimized +# node: Apply +# Node of the function graph to be optimized + +# Returns +# ------- +# list of Variable, optional +# List of optimized variables, or None if no optimization was performed +# """ +# # Check for inner block_diag operation +# potential_block_diag = node.inputs[0].owner +# if not ( +# potential_block_diag +# and isinstance(potential_block_diag.op, Blockwise) +# and isinstance(potential_block_diag.op.core_op, BlockDiagonal) +# ): +# return None + +# # Find the composing sub_matrices +# sub_matrices = potential_block_diag.inputs +# sign_sub_matrices, logdet_sub_matrices = zip( +# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] +# ) + +# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([ExtractDiag]) +def rewrite_diag_kronecker(fgraph, node): + """ + This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector. 
+ + diag(kron(a,b)) -> outer(diag(a), diag(b)) + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None + + # Find the matrices + a, b = potential_kron.inputs + diag_a, diag_b = diag(a), diag(b) + outer_prod_as_vector = outer(diag_a, diag_b).flatten() + + return [outer_prod_as_vector] + + +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def rewrite_det_kronecker(fgraph, node): + """ + This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None + + # Find the matrices + a, b = potential_kron.inputs + dets = [det(a), det(b)] + sizes = [a.shape[-1], b.shape[-1]] + prod_sizes = prod(sizes, no_zeros_in_input=True) + det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)]) + + return [det_final] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_remove_useless_cholesky(fgraph, node): + """ + This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself + + The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky. 
+ + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + # Check whether input to Cholesky is Eye and the 1's are on main diagonal + potential_eye = node.inputs[0] + if not ( + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and hasattr(potential_eye.owner.inputs[-1], "data") + and potential_eye.owner.inputs[-1].data.item() == 0 + ): + return None + return [potential_eye] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + [input] = node.inputs + # Check for use of pt.diag first + if ( + input.owner + and isinstance(input.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(input.owner) + ): + diag_input = input.owner.inputs[0] + cholesky_val = pt.diag(diag_input**0.5) + return [cholesky_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(input) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + [non_eye_input] = non_eye_inputs + + # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_input.type.broadcastable[-2:] == (False, False): + non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) + if eye_input.type.ndim > 2: + non_eye_input = pt.shape_padaxis(non_eye_input, -2) + + return [eye_input * (non_eye_input**0.5)] + + +# SLogDet Rewrites +def check_sign_det(node): + if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sign)): + return False + + return True + + +def check_log_abs_det(node): + if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)): + return False + + potential_abs = node.inputs[0].owner + if not ( + isinstance(potential_abs.op, Elemwise) + and isinstance(potential_abs.op.scalar_op, Abs) + ): + return False + + return True + + +def check_log_det(node): + if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)): + return False + + return True + + +@node_rewriter(tracks=[det]) +def slogdet_specialization(fgraph, node): + x = node.inputs[0] + sign_det_x, slog_det_x = SLogDet()(x) + replacements = {} + for client in list(fgraph.clients.keys()): + # Check for sign(det) + if check_sign_det(client[0].owner): + replacements[client[0].owner.outputs[0]] = sign_det_x + + # Check for log(abs(det)) + elif check_log_abs_det(client[0].owner): + replacements[client[0].owner.outputs[0]] = slog_det_x + + # Check for log(det) + elif check_log_det(client[0].owner): + pass + # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x) + + # Det is used directly for something else, don't rewrite to avoid computing two dets + else: + return None + + return replacements or None +import logging +from collections.abc import Callable +from typing import cast + from pytensor import 
Variable from pytensor import tensor as pt from pytensor.compile import optdb From 9ba4a9a37a6daec3e057fccf0362e1b30ec0c9e3 Mon Sep 17 00:00:00 2001 From: Tanish Date: Thu, 31 Oct 2024 14:12:30 +0530 Subject: [PATCH 04/16] updated checks for specialised rewrite --- pytensor/tensor/rewriting/linalg.py | 43 ++++++++++++----------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index f2eb54db84..0458b1a01b 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -968,30 +968,17 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): # SLogDet Rewrites -def check_sign_det(node): - if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sign)): +def check_log_abs_det(fgraph, client): + # First, we find abs + if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)): return False - return True - - -def check_log_abs_det(node): - if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)): - return False - - potential_abs = node.inputs[0].owner - if not ( - isinstance(potential_abs.op, Elemwise) - and isinstance(potential_abs.op.scalar_op, Abs) - ): - return False - - return True - - -def check_log_det(node): - if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)): - return False + # Check whether log is a client of abs + for client_2 in fgraph.clients[client.outputs[0]]: + if not ( + isinstance(client_2.op, Elemwise) and isinstance(client_2.op.scalar_op, Log) + ): + return False return True @@ -1001,17 +988,21 @@ def slogdet_specialization(fgraph, node): x = node.inputs[0] sign_det_x, slog_det_x = SLogDet()(x) replacements = {} - for client in list(fgraph.clients.keys()): + for client in fgraph.clients[node.outputs[0]]: # Check for sign(det) - if check_sign_det(client[0].owner): + if isinstance(client[0].op, Elemwise) and isinstance( + client[0].op.scalar_op, Sign + ): replacements[client[0].owner.outputs[0]] = sign_det_x # Check for log(abs(det)) - elif check_log_abs_det(client[0].owner): + elif check_log_abs_det(fgraph, client[0]): replacements[client[0].owner.outputs[0]] = slog_det_x # Check for log(det) - elif check_log_det(client[0].owner): + elif isinstance(client[0].op, Elemwise) and isinstance( + client[0].op.scalar_op, Log + ): pass # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x) From 51038efff7914d77ce2103b475d5cc5de356e6e6 Mon Sep 17 00:00:00 2001 From: Tanish Date: Thu, 31 Oct 2024 14:30:36 +0530 Subject: [PATCH 05/16] fixing bad rebase --- pytensor/tensor/rewriting/linalg.py | 989 ---------------------------- 1 file changed, 989 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 0458b1a01b..95e8fb4d64 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1011,992 +1011,3 @@ def slogdet_specialization(fgraph, node): return None return replacements or None -import logging -from collections.abc import Callable -from typing import cast - -from pytensor import Variable -from pytensor import tensor as pt -from pytensor.compile import optdb -from pytensor.graph import Apply, FunctionGraph -from pytensor.graph.rewriting.basic import ( - copy_stack_trace, - in2out, - node_rewriter, -) -from pytensor.scalar.basic import Mul -from pytensor.tensor.basic import ( - AllocDiag, - ExtractDiag, - Eye, - TensorVariable, - concatenate, - diag, - 
diagonal, -) -from pytensor.tensor.blas import Dot22 -from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod -from pytensor.tensor.nlinalg import ( - SVD, - KroneckerProduct, - MatrixInverse, - MatrixPinv, - det, - inv, - kron, - pinv, - svd, -) -from pytensor.tensor.rewriting.basic import ( - register_canonicalize, - register_specialize, - register_stabilize, -) -from pytensor.tensor.slinalg import ( - BlockDiagonal, - Cholesky, - Solve, - SolveBase, - _bilinear_solve_discrete_lyapunov, - block_diag, - cholesky, - solve, - solve_discrete_lyapunov, - solve_triangular, -) - - -logger = logging.getLogger(__name__) -ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv) - - -def is_matrix_transpose(x: TensorVariable) -> bool: - """Check if a variable corresponds to a transpose of the last two axes""" - node = x.owner - if ( - node - and isinstance(node.op, DimShuffle) - and not (node.op.drop or node.op.augment) - ): - [inp] = node.inputs - ndims = inp.type.ndim - if ndims < 2: - return False - transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2) - return node.op.new_order == transpose_order - return False - - -@register_canonicalize -@node_rewriter([DimShuffle]) -def transinv_to_invtrans(fgraph, node): - if is_matrix_transpose(node.outputs[0]): - (A,) = node.inputs - if ( - A.owner - and isinstance(A.owner.op, Blockwise) - and isinstance(A.owner.op.core_op, MatrixInverse) - ): - (X,) = A.owner.inputs - return [A.owner.op(node.op(X))] - - -@register_stabilize -@node_rewriter([Dot, Dot22]) -def inv_as_solve(fgraph, node): - """ - This utilizes a boolean `symmetric` tag on the matrices. - """ - if isinstance(node.op, Dot | Dot22): - l, r = node.inputs - if ( - l.owner - and isinstance(l.owner.op, Blockwise) - and isinstance(l.owner.op.core_op, MatrixInverse) - ): - return [solve(l.owner.inputs[0], r)] - if ( - r.owner - and isinstance(r.owner.op, Blockwise) - and isinstance(r.owner.op.core_op, MatrixInverse) - ): - x = r.owner.inputs[0] - if getattr(x.tag, "symmetric", None) is True: - return [solve(x, (l.mT)).mT] - else: - return [solve((x.mT), (l.mT)).mT] - - -@register_stabilize -@register_canonicalize -@node_rewriter([Blockwise]) -def generic_solve_to_solve_triangular(fgraph, node): - """ - If any solve() is applied to the output of a cholesky op, then - replace it with a triangular solve. 
- - """ - if isinstance(node.op.core_op, Solve): - if node.op.core_op.assume_a == "gen": - A, b = node.inputs # result is solution Ax=b - if ( - A.owner - and isinstance(A.owner.op, Blockwise) - and isinstance(A.owner.op.core_op, Cholesky) - ): - if A.owner.op.core_op.lower: - return [ - solve_triangular( - A, b, lower=True, b_ndim=node.op.core_op.b_ndim - ) - ] - else: - return [ - solve_triangular( - A, b, lower=False, b_ndim=node.op.core_op.b_ndim - ) - ] - if is_matrix_transpose(A): - (A_T,) = A.owner.inputs - if ( - A_T.owner - and isinstance(A_T.owner.op, Blockwise) - and isinstance(A_T.owner.op, Cholesky) - ): - if A_T.owner.op.lower: - return [ - solve_triangular( - A, b, lower=False, b_ndim=node.op.core_op.b_ndim - ) - ] - else: - return [ - solve_triangular( - A, b, lower=True, b_ndim=node.op.core_op.b_ndim - ) - ] - - -@register_specialize -@node_rewriter([Blockwise]) -def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): - """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T - - `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions. - """ - core_op = node.op.core_op - - if not isinstance(core_op, SolveBase): - return None - - if node.op.core_op.b_ndim != 1: - return None - - [a, b] = node.inputs - - # Check `b` is actually batched - if b.type.ndim == 1: - return None - - # Check `a` is a matrix (possibly with degenerate dims on the left) - a_bcast_batch_dims = a.type.broadcastable[:-2] - if not all(a_bcast_batch_dims): - return None - # We squeeze degenerate dims, any that are still needed will be introduced by the new_solve - elif len(a_bcast_batch_dims): - a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims)))) - - # Recreate solve Op with b_ndim=2 - props = core_op._props_dict() - props["b_ndim"] = 2 - new_core_op = type(core_op)(**props) - matrix_b_solve = Blockwise(new_core_op) - - # Ravel any batched dims - original_b_shape = tuple(b.shape) - if len(original_b_shape) > 2: - b = b.reshape((-1, original_b_shape[-1])) - - # Apply the rewrite - new_solve = matrix_b_solve(a, b.T).T - - # Unravel any batched dims - if len(original_b_shape) > 2: - new_solve = new_solve.reshape(original_b_shape) - - old_solve = node.outputs[0] - copy_stack_trace(old_solve, new_solve) - - return [new_solve] - - -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([DimShuffle]) -def no_transpose_symmetric(fgraph, node): - if is_matrix_transpose(node.outputs[0]): - x = node.inputs[0] - if getattr(x.tag, "symmetric", None): - return [x] - - -@register_stabilize -@node_rewriter([Blockwise]) -def psd_solve_with_chol(fgraph, node): - """ - This utilizes a boolean `psd` tag on matrices. - """ - if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2: - A, b = node.inputs # result is solution Ax=b - if getattr(A.tag, "psd", None) is True: - L = cholesky(A) - # N.B. this can be further reduced to a yet-unwritten cho_solve Op - # __if__ no other Op makes use of the L matrix during the - # stabilization - Li_b = solve_triangular(L, b, lower=True, b_ndim=2) - x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2) - return [x] - - -@register_canonicalize -@register_stabilize -@node_rewriter([Blockwise]) -def cholesky_ldotlt(fgraph, node): - """ - rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular, - or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular. - - Also works with matmul. - - This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. 
- """ - if not isinstance(node.op.core_op, Cholesky): - return - - A = node.inputs[0] - if not ( - A.owner is not None - and ( - ( - isinstance(A.owner.op, Dot | Dot22) - # This rewrite only applies to matrix Dot - and A.owner.inputs[0].type.ndim == 2 - ) - or (A.owner.op == _matrix_matrix_matmul) - ) - ): - return - - l, r = A.owner.inputs - - # cholesky(dot(L,L.T)) case - if ( - getattr(l.tag, "lower_triangular", False) - and is_matrix_transpose(r) - and r.owner.inputs[0] == l - ): - if node.op.core_op.lower: - return [l] - return [r] - - # cholesky(dot(U.T,U)) case - if ( - getattr(r.tag, "upper_triangular", False) - and is_matrix_transpose(l) - and l.owner.inputs[0] == r - ): - if node.op.core_op.lower: - return [l] - return [r] - - -@register_stabilize -@register_specialize -@node_rewriter([det]) -def local_det_chol(fgraph, node): - """ - If we have det(X) and there is already an L=cholesky(X) - floating around, then we can use prod(diag(L)) to get the determinant. - - """ - (x,) = node.inputs - for cl, xpos in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): - L = cl.outputs[0] - return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] - - -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([log]) -def local_log_prod_sqr(fgraph, node): - """ - This utilizes a boolean `positive` tag on matrices. - """ - (x,) = node.inputs - if x.owner and isinstance(x.owner.op, Prod): - # we cannot always make this substitution because - # the prod might include negative terms - p = x.owner.inputs[0] - - # p is the matrix we're reducing with prod - if getattr(p.tag, "positive", None) is True: - return [log(p).sum(axis=x.owner.op.axis)] - - # TODO: have a reduction like prod and sum that simply - # returns the sign of the prod multiplication. - - -@register_specialize -@node_rewriter([Blockwise]) -def local_lift_through_linalg( - fgraph: FunctionGraph, node: Apply -) -> list[Variable] | None: - """ - Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops - that join matrices (KroneckerProduct, BlockDiagonal). - - This rewrite takes advantage of commutation between certain linear algebra operations to do several smaller matrix - operations on component matrices instead of one large one. For example, when taking the inverse of Kronecker - product, we can take the inverse of each component matrix and then take the Kronecker product of the inverses. This - reduces the cost of the inverse from O((n*m)^3) to O(n^3 + m^3) where n and m are the dimensions of the component - matrices. 
- - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - - # TODO: Simplify this if we end up Blockwising KroneckerProduct - if not isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv): - return None - - y = node.inputs[0] - outer_op = node.op - - if y.owner and ( - isinstance(y.owner.op, Blockwise) - and isinstance(y.owner.op.core_op, BlockDiagonal) - or isinstance(y.owner.op, KroneckerProduct) - ): - input_matrices = y.owner.inputs - - if isinstance(outer_op.core_op, MatrixInverse): - outer_f = cast(Callable, inv) - elif isinstance(outer_op.core_op, Cholesky): - outer_f = cast(Callable, cholesky) - elif isinstance(outer_op.core_op, MatrixPinv): - outer_f = cast(Callable, pinv) - else: - raise NotImplementedError # pragma: no cover - - inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices] - - if isinstance(y.owner.op, KroneckerProduct): - return [kron(*inner_matrices)] - elif isinstance(y.owner.op.core_op, BlockDiagonal): - return [block_diag(*inner_matrices)] - else: - raise NotImplementedError # pragma: no cover - return None - - -def _find_diag_from_eye_mul(potential_mul_input): - # Check if the op is Elemwise and mul - if not ( - potential_mul_input.owner is not None - and isinstance(potential_mul_input.owner.op, Elemwise) - and isinstance(potential_mul_input.owner.op.scalar_op, Mul) - ): - return None - - # Find whether any of the inputs to mul is Eye - inputs_to_mul = potential_mul_input.owner.inputs - eye_input = [ - mul_input - for mul_input in inputs_to_mul - if mul_input.owner - and ( - isinstance(mul_input.owner.op, Eye) - or - # This whole condition checks if there is an Eye hiding inside a DimShuffle. - # This arises from batched elementwise multiplication between a tensor and an eye, e.g.: - # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites. - ( - isinstance(mul_input.owner.op, DimShuffle) - and ( - mul_input.owner.op.is_left_expand_dims - or mul_input.owner.op.is_right_expand_dims - ) - and mul_input.owner.inputs[0].owner is not None - and isinstance(mul_input.owner.inputs[0].owner.op, Eye) - ) - ) - ] - - if not eye_input: - return None - - eye_input = eye_input[0] - # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset - if isinstance(eye_input.owner.op, Eye) and ( - not Eye.is_offset_zero(eye_input.owner) - or eye_input.broadcastable[-2:] != (False, False) - ): - return None - - # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on). 
- # We have to look inside DimShuffle to decide if the rewrite can be applied - if isinstance(eye_input.owner.op, DimShuffle) and ( - eye_input.owner.op.is_left_expand_dims - or eye_input.owner.op.is_right_expand_dims - ): - inner_eye = eye_input.owner.inputs[0] - # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't - # degenerate - if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != ( - False, - False, - ): - return None - - # Get all non Eye inputs (scalars/matrices/vectors) - non_eye_inputs = list(set(inputs_to_mul) - {eye_input}) - return eye_input, non_eye_inputs - - -@register_canonicalize("shape_unsafe") -@register_stabilize("shape_unsafe") -@node_rewriter([det]) -def rewrite_det_diag_to_prod_diag(fgraph, node): - """ - This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its - diagonal elements. - - The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices - that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to - make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, - vector or a matrix. - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - inputs = node.inputs[0] - - # Check for use of pt.diag first - if ( - inputs.owner - and isinstance(inputs.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(inputs.owner) - ): - diag_input = inputs.owner.inputs[0] - det_val = diag_input.prod(axis=-1) - return [det_val] - - # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(inputs) - if inputs_or_none is None: - return None - - eye_input, non_eye_inputs = inputs_or_none - - # Dealing with only one other input - if len(non_eye_inputs) != 1: - return None - - eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] - - # Checking if original x was scalar/vector/matrix - if non_eye_input.type.broadcastable[-2:] == (True, True): - # For scalar - det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0]) - elif non_eye_input.type.broadcastable[-2:] == (False, False): - # For Matrix - det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1) - else: - # For vector - det_val = non_eye_input.prod(axis=(-1, -2)) - det_val = det_val.astype(node.outputs[0].type.dtype) - return [det_val] - - -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([Blockwise]) -def svd_uv_merge(fgraph, node): - """If we have more than one `SVD` `Op`s and at least one has keyword argument - `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere - and allow `pytensor` to re-use the decomposition outputs instead of recomputing. - """ - if not isinstance(node.op.core_op, SVD): - return - - (x,) = node.inputs - - if node.op.core_op.compute_uv: - # compute_uv=True returns [u, s, v]. - # if at least u or v is used, no need to rewrite this node. - if ( - len(fgraph.clients[node.outputs[0]]) > 0 - or len(fgraph.clients[node.outputs[2]]) > 0 - ): - return - - # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. 
- # First, iterate to see if there is an SVD Op that can be reused. - for cl, _ in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if not cl.op.core_op.compute_uv: - return { - node.outputs[1]: cl.outputs[0], - } - - # If no SVD reusable, return a new one. - return { - node.outputs[1]: svd( - x, full_matrices=node.op.core_op.full_matrices, compute_uv=False - ), - } - - else: - # compute_uv=False returns [s]. - # We want rewrite if there is another one with compute_uv=True. - # For this case, just reuse the `s` from the one with compute_uv=True. - for cl, _ in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if cl.op.core_op.compute_uv and ( - len(fgraph.clients[cl.outputs[0]]) > 0 - or len(fgraph.clients[cl.outputs[2]]) > 0 - ): - return [cl.outputs[1]] - - -@register_canonicalize -@register_stabilize -@node_rewriter([Blockwise]) -def rewrite_inv_inv(fgraph, node): - """ - This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once. - - Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten. - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check if its a valid inverse operation (either inv/pinv) - # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation - # If the outer operation is not a valid inverse, we do not apply this rewrite - if not isinstance(node.op.core_op, ALL_INVERSE_OPS): - return None - - potential_inner_inv = node.inputs[0].owner - if potential_inner_inv is None or potential_inner_inv.op is None: - return None - - # Check if inner op is blockwise and and possible inv - if not ( - potential_inner_inv - and isinstance(potential_inner_inv.op, Blockwise) - and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS) - ): - return None - return [potential_inner_inv.inputs[0]] - - -@register_canonicalize -@register_stabilize -@node_rewriter([Blockwise]) -def rewrite_inv_eye_to_eye(fgraph, node): - """ - This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself - The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op. 
- Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - core_op = node.op.core_op - if not (isinstance(core_op, ALL_INVERSE_OPS)): - return None - - # Check whether input to inverse is Eye and the 1's are on main diagonal - potential_eye = node.inputs[0] - if not ( - potential_eye.owner - and isinstance(potential_eye.owner.op, Eye) - and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0 - ): - return None - return [potential_eye] - - -@register_canonicalize -@register_stabilize -@node_rewriter([Blockwise]) -def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): - """ - This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. - This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - core_op = node.op.core_op - if not (isinstance(core_op, ALL_INVERSE_OPS)): - return None - - inputs = node.inputs[0] - # Check for use of pt.diag first - if ( - inputs.owner - and isinstance(inputs.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(inputs.owner) - ): - inv_input = inputs.owner.inputs[0] - inv_val = pt.diag(1 / inv_input) - return [inv_val] - - # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(inputs) - if inputs_or_none is None: - return None - - eye_input, non_eye_inputs = inputs_or_none - - # Dealing with only one other input - if len(non_eye_inputs) != 1: - return None - - non_eye_input = non_eye_inputs[0] - - # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those - if non_eye_input.type.broadcastable[-2:] == (False, False): - non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2) - non_eye_input = pt.shape_padaxis(non_eye_diag, -2) - - return [eye_input / non_eye_input] - - -@register_canonicalize -@register_stabilize -@node_rewriter([ExtractDiag]) -def rewrite_diag_blockdiag(fgraph, node): - """ - This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices. - - diag(block_diag(a,b,c,....)) = concat(diag(a), diag(b), diag(c),...) 
- - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - submatrices = potential_block_diag.inputs - submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))] - - return [concatenate(submatrices_diag)] - - -@register_canonicalize -@register_stabilize -@node_rewriter([det]) -def rewrite_det_blockdiag(fgraph, node): - """ - This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values. - - det(block_diag(a,b,c,....)) = prod(det(a), det(b), det(c),...) - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - sub_matrices = potential_block_diag.inputs - det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))] - - return [prod(det_sub_matrices)] - - -# @register_canonicalize -# @register_stabilize -# @node_rewriter([slogdet]) -# def rewrite_slogdet_blockdiag(fgraph, node): -# """ -# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those - -# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) - -# Parameters -# ---------- -# fgraph: FunctionGraph -# Function graph being optimized -# node: Apply -# Node of the function graph to be optimized - -# Returns -# ------- -# list of Variable, optional -# List of optimized variables, or None if no optimization was performed -# """ -# # Check for inner block_diag operation -# potential_block_diag = node.inputs[0].owner -# if not ( -# potential_block_diag -# and isinstance(potential_block_diag.op, Blockwise) -# and isinstance(potential_block_diag.op.core_op, BlockDiagonal) -# ): -# return None - -# # Find the composing sub_matrices -# sub_matrices = potential_block_diag.inputs -# sign_sub_matrices, logdet_sub_matrices = zip( -# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] -# ) - -# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] - - -@register_canonicalize -@register_stabilize -@node_rewriter([ExtractDiag]) -def rewrite_diag_kronecker(fgraph, node): - """ - This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector. 
- - diag(kron(a,b)) -> outer(diag(a), diag(b)) - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner kron operation - potential_kron = node.inputs[0].owner - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): - return None - - # Find the matrices - a, b = potential_kron.inputs - diag_a, diag_b = diag(a), diag(b) - outer_prod_as_vector = outer(diag_a, diag_b).flatten() - - return [outer_prod_as_vector] - - -@register_canonicalize -@register_stabilize -@node_rewriter([det]) -def rewrite_det_kronecker(fgraph, node): - """ - This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those - - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner kron operation - potential_kron = node.inputs[0].owner - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): - return None - - # Find the matrices - a, b = potential_kron.inputs - dets = [det(a), det(b)] - sizes = [a.shape[-1], b.shape[-1]] - prod_sizes = prod(sizes, no_zeros_in_input=True) - det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)]) - - return [det_final] - - -@register_canonicalize -@register_stabilize -@node_rewriter([Blockwise]) -def rewrite_remove_useless_cholesky(fgraph, node): - """ - This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself - - The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky. 
- - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Find whether cholesky op is being applied - if not isinstance(node.op.core_op, Cholesky): - return None - - # Check whether input to Cholesky is Eye and the 1's are on main diagonal - potential_eye = node.inputs[0] - if not ( - potential_eye.owner - and isinstance(potential_eye.owner.op, Eye) - and hasattr(potential_eye.owner.inputs[-1], "data") - and potential_eye.owner.inputs[-1].data.item() == 0 - ): - return None - return [potential_eye] - - -@register_canonicalize -@register_stabilize -@node_rewriter([Blockwise]) -def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): - # Find whether cholesky op is being applied - if not isinstance(node.op.core_op, Cholesky): - return None - - [input] = node.inputs - # Check for use of pt.diag first - if ( - input.owner - and isinstance(input.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(input.owner) - ): - diag_input = input.owner.inputs[0] - cholesky_val = pt.diag(diag_input**0.5) - return [cholesky_val] - - # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(input) - if inputs_or_none is None: - return None - - eye_input, non_eye_inputs = inputs_or_none - - # Dealing with only one other input - if len(non_eye_inputs) != 1: - return None - - [non_eye_input] = non_eye_inputs - - # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements - # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those - if non_eye_input.type.broadcastable[-2:] == (False, False): - non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) - if eye_input.type.ndim > 2: - non_eye_input = pt.shape_padaxis(non_eye_input, -2) - - return [eye_input * (non_eye_input**0.5)] - - -@node_rewriter([_bilinear_solve_discrete_lyapunov]) -def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): - """ - Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX - """ - A, B = (cast(TensorVariable, x) for x in node.inputs) - result = solve_discrete_lyapunov(A, B, method="direct") - - return [result] - - -optdb.register( - "jax_bilinaer_lyapunov_to_direct", - in2out(jax_bilinaer_lyapunov_to_direct), - "jax", - position=0.9, # Run before canonicalization -) From fa4ed0b93846c449a0ce919d0a5835ce4e1bde28 Mon Sep 17 00:00:00 2001 From: Tanish Date: Fri, 1 Nov 2024 00:00:52 +0530 Subject: [PATCH 06/16] updated specialisation rewrite --- pytensor/tensor/rewriting/linalg.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 95e8fb4d64..864f952dbf 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -985,26 +985,30 @@ def check_log_abs_det(fgraph, client): @node_rewriter(tracks=[det]) def slogdet_specialization(fgraph, node): - x = node.inputs[0] - sign_det_x, slog_det_x = SLogDet()(x) replacements = {} for client in fgraph.clients[node.outputs[0]]: # Check for sign(det) if isinstance(client[0].op, Elemwise) and isinstance( client[0].op.scalar_op, Sign ): - replacements[client[0].owner.outputs[0]] = sign_det_x + x = node.inputs[0] + sign_det_x, 
slog_det_x = SLogDet()(x) + replacements[client[0].outputs[0]] = sign_det_x # Check for log(abs(det)) elif check_log_abs_det(fgraph, client[0]): - replacements[client[0].owner.outputs[0]] = slog_det_x + x = node.inputs[0] + sign_det_x, slog_det_x = SLogDet()(x) + replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = ( + slog_det_x + ) # Check for log(det) - elif isinstance(client[0].op, Elemwise) and isinstance( - client[0].op.scalar_op, Log - ): - pass - # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x) + # elif isinstance(client[0].op, Elemwise) and isinstance( + # client[0].op.scalar_op, Log + # ): + # pass + # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x) # Det is used directly for something else, don't rewrite to avoid computing two dets else: From 03215e82fb6b884e25871033f40d90ecb8ddea79 Mon Sep 17 00:00:00 2001 From: Tanish Date: Fri, 1 Nov 2024 15:34:57 +0530 Subject: [PATCH 07/16] working specialised rewrite + test --- pytensor/tensor/rewriting/linalg.py | 26 ++++++++++++++++---------- tests/tensor/rewriting/test_linalg.py | 21 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 864f952dbf..53e0d8cfdf 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,6 +2,8 @@ from collections.abc import Callable from typing import cast +import numpy as np + from pytensor import Variable from pytensor import tensor as pt from pytensor.graph import Apply, FunctionGraph @@ -967,8 +969,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): return [eye_input * (non_eye_input**0.5)] -# SLogDet Rewrites -def check_log_abs_det(fgraph, client): +def _check_log_abs_det(fgraph, client): # First, we find abs if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)): return False @@ -976,14 +977,16 @@ def check_log_abs_det(fgraph, client): # Check whether log is a client of abs for client_2 in fgraph.clients[client.outputs[0]]: if not ( - isinstance(client_2.op, Elemwise) and isinstance(client_2.op.scalar_op, Log) + isinstance(client_2[0].op, Elemwise) + and isinstance(client_2[0].op.scalar_op, Log) ): return False return True -@node_rewriter(tracks=[det]) +@register_specialize +@node_rewriter([det]) def slogdet_specialization(fgraph, node): replacements = {} for client in fgraph.clients[node.outputs[0]]: @@ -996,7 +999,7 @@ def slogdet_specialization(fgraph, node): replacements[client[0].outputs[0]] = sign_det_x # Check for log(abs(det)) - elif check_log_abs_det(fgraph, client[0]): + elif _check_log_abs_det(fgraph, client[0]): x = node.inputs[0] sign_det_x, slog_det_x = SLogDet()(x) replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = ( @@ -1004,11 +1007,14 @@ def slogdet_specialization(fgraph, node): ) # Check for log(det) - # elif isinstance(client[0].op, Elemwise) and isinstance( - # client[0].op.scalar_op, Log - # ): - # pass - # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x) + elif isinstance(client[0].op, Elemwise) and isinstance( + client[0].op.scalar_op, Log + ): + x = node.inputs[0] + sign_det_x, slog_det_x = SLogDet()(x) + replacements[client[0].outputs[0]] = pt.where( + pt.eq(sign_det_x, -1), np.nan, slog_det_x + ) # Det is used directly for something else, don't rewrite to avoid computing two dets else: diff --git 
a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9ff0ff76d9..7658bc2404 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -21,6 +21,7 @@ KroneckerProduct, MatrixInverse, MatrixPinv, + SLogDet, matrix_inverse, svd, ) @@ -900,3 +901,23 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): f_rewritten = function([x], z_cholesky, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Cholesky) for node in nodes) + + +def test_slogdet_specialisation(): + x = pt.dmatrix("x") + det_x = pt.linalg.det(x) + log_abs_det_x = pt.log(pt.abs(det_x)) + sign_det_x = pt.sign(det_x) + exp_det_x = pt.exp(det_x) + # sign(det(x)) + f = function([x], [sign_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert any(isinstance(node.op, SLogDet) for node in nodes) + # log(abs(det(x))) + f = function([x], [log_abs_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert any(isinstance(node.op, SLogDet) for node in nodes) + # other functions (rewrite shouldnt be applied to these) + f = function([x], [exp_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, SLogDet) for node in nodes) From 6927516c095c51d2100941ab849651f9c59884d9 Mon Sep 17 00:00:00 2001 From: Tanish Date: Thu, 14 Nov 2024 23:58:39 +0530 Subject: [PATCH 08/16] added all tests --- pytensor/tensor/nlinalg.py | 1 - pytensor/tensor/rewriting/linalg.py | 46 +++++---------- tests/tensor/rewriting/test_linalg.py | 85 ++++++++++++++++++++++++--- 3 files changed, 92 insertions(+), 40 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index a1015f51e8..30b48d2756 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -266,7 +266,6 @@ def __str__(self): return "SLogDet" -# slogdet = Blockwise(SLogDet()) def slogdet(x): det_val = det(x) return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 53e0d8cfdf..7afae88ccc 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -969,50 +969,34 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): return [eye_input * (non_eye_input**0.5)] -def _check_log_abs_det(fgraph, client): - # First, we find abs - if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)): - return False - - # Check whether log is a client of abs - for client_2 in fgraph.clients[client.outputs[0]]: - if not ( - isinstance(client_2[0].op, Elemwise) - and isinstance(client_2[0].op.scalar_op, Log) - ): - return False - - return True - - @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): replacements = {} - for client in fgraph.clients[node.outputs[0]]: + for client, _ in fgraph.clients[node.outputs[0]]: # Check for sign(det) - if isinstance(client[0].op, Elemwise) and isinstance( - client[0].op.scalar_op, Sign - ): + if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign): x = node.inputs[0] sign_det_x, slog_det_x = SLogDet()(x) - replacements[client[0].outputs[0]] = sign_det_x + replacements[client.outputs[0]] = sign_det_x # Check for log(abs(det)) - elif _check_log_abs_det(fgraph, client[0]): - x = node.inputs[0] - sign_det_x, slog_det_x = SLogDet()(x) - replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = ( - slog_det_x - ) + elif isinstance(client.op, Elemwise) and 
isinstance(client.op.scalar_op, Abs): + for client_2, _ in fgraph.clients[client.outputs[0]]: + if isinstance(client_2.op, Elemwise) and isinstance( + client_2.op.scalar_op, Log + ): + x = node.inputs[0] + sign_det_x, slog_det_x = SLogDet()(x) + replacements[fgraph.clients[client.outputs[0]][0][0].outputs[0]] = ( + slog_det_x + ) # Check for log(det) - elif isinstance(client[0].op, Elemwise) and isinstance( - client[0].op.scalar_op, Log - ): + elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log): x = node.inputs[0] sign_det_x, slog_det_x = SLogDet()(x) - replacements[client[0].outputs[0]] = pt.where( + replacements[client.outputs[0]] = pt.where( pt.eq(sign_det_x, -1), np.nan, slog_det_x ) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 7658bc2404..3a1f84454f 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -781,7 +781,7 @@ def test_det_kronecker_rewrite(): a, b = pt.dmatrices("a", "b") kron_prod = pt.linalg.kron(a, b) det_output = pt.linalg.det(kron_prod) - f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN") + f_rewritten = function([a, b], [det_output], mode="FAST_RUN") # Rewrite Test nodes = f_rewritten.maker.fgraph.apply_nodes @@ -791,7 +791,7 @@ def test_det_kronecker_rewrite(): a_test, b_test = np.random.rand(2, 20, 20) kron_prod_test = np.kron(a_test, b_test) det_output_test = np.linalg.det(kron_prod_test) - rewritten_det_val = f_rewritten(kron_prod_test) + rewritten_det_val = f_rewritten(a_test, b_test) assert_allclose( det_output_test, rewritten_det_val, @@ -800,6 +800,35 @@ def test_det_kronecker_rewrite(): ) +def test_slogdet_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + sign_output, logdet_output = pt.linalg.slogdet(kron_prod) + f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test) + rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test) + assert_allclose( + sign_output_test, + rewritten_sign_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + logdet_output_test, + rewritten_logdet_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + def test_cholesky_eye_rewrite(): x = pt.eye(10) L = pt.linalg.cholesky(x) @@ -904,20 +933,60 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): def test_slogdet_specialisation(): - x = pt.dmatrix("x") - det_x = pt.linalg.det(x) - log_abs_det_x = pt.log(pt.abs(det_x)) - sign_det_x = pt.sign(det_x) + x, a = pt.dmatrix("x"), np.random.rand(20, 20) + det_x, det_a = pt.linalg.det(x), np.linalg.det(a) + log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a)) + log_det_x, log_det_a = pt.log(det_x), np.log(det_a) + sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a) exp_det_x = pt.exp(det_x) + # REWRITE TESTS # sign(det(x)) f = function([x], [sign_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes - assert any(isinstance(node.op, SLogDet) for node in nodes) + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + 
assert not any(isinstance(node.op, Det) for node in nodes) + rw_sign_det_a = f(a) + assert_allclose( + sign_det_a, + rw_sign_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) # log(abs(det(x))) f = function([x], [log_abs_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes - assert any(isinstance(node.op, SLogDet) for node in nodes) + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_log_abs_det_a = f(a) + assert_allclose( + log_abs_det_a, + rw_log_abs_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + # log(det(x)) + f = function([x], [log_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_log_det_a = f(a) + assert_allclose( + log_det_a, + rw_log_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + # more than 1 valid function + f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) # other functions (rewrite shouldnt be applied to these) + # only invalid functions f = function([x], [exp_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes assert not any(isinstance(node.op, SLogDet) for node in nodes) + # invalid + valid function + f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, SLogDet) for node in nodes) From 8f6badf06db5e9d439567c9602797783485c4427 Mon Sep 17 00:00:00 2001 From: Tanish Date: Fri, 15 Nov 2024 02:03:52 +0530 Subject: [PATCH 09/16] splitting rewrite into 2 stages done --- pytensor/tensor/rewriting/linalg.py | 31 ++++++++++++++++------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 7afae88ccc..d79e15bbfc 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -972,13 +972,11 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): - replacements = {} + dummy_replacements = {} for client, _ in fgraph.clients[node.outputs[0]]: # Check for sign(det) if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign): - x = node.inputs[0] - sign_det_x, slog_det_x = SLogDet()(x) - replacements[client.outputs[0]] = sign_det_x + dummy_replacements[client.outputs[0]] = "sign" # Check for log(abs(det)) elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs): @@ -986,22 +984,27 @@ def slogdet_specialization(fgraph, node): if isinstance(client_2.op, Elemwise) and isinstance( client_2.op.scalar_op, Log ): - x = node.inputs[0] - sign_det_x, slog_det_x = SLogDet()(x) - replacements[fgraph.clients[client.outputs[0]][0][0].outputs[0]] = ( - slog_det_x - ) + dummy_replacements[ + fgraph.clients[client.outputs[0]][0][0].outputs[0] + ] = "log_abs_det" # Check for log(det) elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log): - x = node.inputs[0] - sign_det_x, slog_det_x = SLogDet()(x) - 
replacements[client.outputs[0]] = pt.where( - pt.eq(sign_det_x, -1), np.nan, slog_det_x - ) + dummy_replacements[client.outputs[0]] = "log_det" # Det is used directly for something else, don't rewrite to avoid computing two dets else: return None + [x] = node.inputs + sign_det_x, log_abs_det_x = SLogDet()(x) + log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) + slogdet_specialization_map = { + "sign": sign_det_x, + "log_abs_det": log_abs_det_x, + "log_det": log_det_x, + } + replacements = { + k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() + } return replacements or None From 69765a198309fec2317ee93d17f16ec07ea67c64 Mon Sep 17 00:00:00 2001 From: Tanish Date: Fri, 15 Nov 2024 13:58:12 +0530 Subject: [PATCH 10/16] fixed failing precision test --- tests/tensor/rewriting/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 3a1f84454f..fdcfc3ec91 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -720,7 +720,7 @@ def test_det_blockdiag_rewrite(): def test_slogdet_blockdiag_rewrite(): - n_matrices = 100 + n_matrices = 10 matrix_size = (5, 5) sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) From 60e34fee3eecb8f1b7ed812371e48a943b0c978b Mon Sep 17 00:00:00 2001 From: Tanish Date: Fri, 15 Nov 2024 16:41:55 +0530 Subject: [PATCH 11/16] added missing function becaues of rebase --- pytensor/tensor/rewriting/linalg.py | 62 +++++++++++------------------ 1 file changed, 23 insertions(+), 39 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index d79e15bbfc..c6d3416334 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -6,9 +6,11 @@ from pytensor import Variable from pytensor import tensor as pt +from pytensor.compile import optdb from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, + in2out, node_rewriter, ) from pytensor.scalar.basic import Abs, Log, Mul, Sign @@ -47,9 +49,11 @@ Cholesky, Solve, SolveBase, + _bilinear_solve_discrete_lyapunov, block_diag, cholesky, solve, + solve_discrete_lyapunov, solve_triangular, ) @@ -783,45 +787,6 @@ def rewrite_det_blockdiag(fgraph, node): return [prod(det_sub_matrices)] -# @register_canonicalize -# @register_stabilize -# @node_rewriter([slogdet]) -# def rewrite_slogdet_blockdiag(fgraph, node): -# """ -# This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those - -# slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) 
- -# Parameters -# ---------- -# fgraph: FunctionGraph -# Function graph being optimized -# node: Apply -# Node of the function graph to be optimized - -# Returns -# ------- -# list of Variable, optional -# List of optimized variables, or None if no optimization was performed -# """ -# # Check for inner block_diag operation -# potential_block_diag = node.inputs[0].owner -# if not ( -# potential_block_diag -# and isinstance(potential_block_diag.op, Blockwise) -# and isinstance(potential_block_diag.op.core_op, BlockDiagonal) -# ): -# return None - -# # Find the composing sub_matrices -# sub_matrices = potential_block_diag.inputs -# sign_sub_matrices, logdet_sub_matrices = zip( -# *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] -# ) - -# return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] - - @register_canonicalize @register_stabilize @node_rewriter([ExtractDiag]) @@ -969,6 +934,25 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): return [eye_input * (non_eye_input**0.5)] +@node_rewriter([_bilinear_solve_discrete_lyapunov]) +def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): + """ + Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX + """ + A, B = (cast(TensorVariable, x) for x in node.inputs) + result = solve_discrete_lyapunov(A, B, method="direct") + + return [result] + + +optdb.register( + "jax_bilinaer_lyapunov_to_direct", + in2out(jax_bilinaer_lyapunov_to_direct), + "jax", + position=0.9, # Run before canonicalization +) + + @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): From 6e4f04b8d34fa9feca43b1aa340de7ea0dac911a Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 16 Nov 2024 13:06:45 +0530 Subject: [PATCH 12/16] minor changes --- pytensor/tensor/nlinalg.py | 5 ++- pytensor/tensor/rewriting/linalg.py | 51 +++++++++++++++++++-------- tests/tensor/rewriting/test_linalg.py | 2 +- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 30b48d2756..083b82d934 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -266,7 +266,10 @@ def __str__(self): return "SLogDet" -def slogdet(x): +def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: + """ + This function simplfies the slogdet operation into 2 separate operations using directly the det op : sign(det_val) and log(abs(det_val)) + """ det_val = det(x) return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index c6d3416334..cd202fe3ed 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -956,6 +956,21 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): + """ + This rewrite targets specific operations related to slogdet i.e sign(det), log(det) and log(abs(det)) and rewrites them using the SLogDet operation. 
+ + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + dictionary of Variables, optional + Dictionary of nodes and what they should be replaced with, or None if no optimization was performed + """ dummy_replacements = {} for client, _ in fgraph.clients[node.outputs[0]]: # Check for sign(det) @@ -964,13 +979,16 @@ def slogdet_specialization(fgraph, node): # Check for log(abs(det)) elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs): + potential_log = None for client_2, _ in fgraph.clients[client.outputs[0]]: if isinstance(client_2.op, Elemwise) and isinstance( client_2.op.scalar_op, Log ): - dummy_replacements[ - fgraph.clients[client.outputs[0]][0][0].outputs[0] - ] = "log_abs_det" + potential_log = client_2 + if potential_log: + dummy_replacements[potential_log.outputs[0]] = "log_abs_det" + else: + return None # Check for log(det) elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log): @@ -980,15 +998,18 @@ def slogdet_specialization(fgraph, node): else: return None - [x] = node.inputs - sign_det_x, log_abs_det_x = SLogDet()(x) - log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) - slogdet_specialization_map = { - "sign": sign_det_x, - "log_abs_det": log_abs_det_x, - "log_det": log_det_x, - } - replacements = { - k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() - } - return replacements or None + if not dummy_replacements: + return None + else: + [x] = node.inputs + sign_det_x, log_abs_det_x = SLogDet()(x) + log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) + slogdet_specialization_map = { + "sign": sign_det_x, + "log_abs_det": log_abs_det_x, + "log_det": log_det_x, + } + replacements = { + k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() + } + return replacements diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index fdcfc3ec91..8ac1bb188e 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -932,7 +932,7 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): assert any(isinstance(node.op, Cholesky) for node in nodes) -def test_slogdet_specialisation(): +def test_slogdet_specialization(): x, a = pt.dmatrix("x"), np.random.rand(20, 20) det_x, det_a = pt.linalg.det(x), np.linalg.det(a) log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a)) From 7a05cb1d20f0154899cb258d379620bcadbd51dd Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 16 Nov 2024 15:30:04 +0530 Subject: [PATCH 13/16] check for pytorch failing test --- pytensor/tensor/nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 083b82d934..1095918cbb 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -271,7 +271,7 @@ def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariab This function simplfies the slogdet operation into 2 separate operations using directly the det op : sign(det_val) and log(abs(det_val)) """ det_val = det(x) - return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) + return [ptm.sign(det_val), ptm.log(ptm.abs(det_val))] class Eig(Op): From 7f179af61e8e203cc85c25409d354a518472d696 Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 16 Nov 2024 23:12:38 +0530 Subject: [PATCH 14/16] fixed mypy error and pytorch linking test --- 
pytensor/tensor/nlinalg.py | 24 +++++++++++++++++++++++- tests/link/pytorch/test_nlinalg.py | 6 ++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 1095918cbb..95ce771a66 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -268,7 +268,29 @@ def __str__(self): def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: """ - This function simplfies the slogdet operation into 2 separate operations using directly the det op : sign(det_val) and log(abs(det_val)) + Compute the sign and (natural) logarithm of the determinant of an array. + + Returns a naive graph which is optimized later using rewrites with the det operation. + + Parameters + ---------- + x : (..., M, M) tensor or tensor_like + Input tensor, has to be square. + + Returns + ------- + A namedtuple with the following attributes: + + sign : (...) tensor_like + A number representing the sign of the determinant. For a real matrix, + this is 1, 0, or -1. For a complex matrix, this is a complex number + with absolute value 1 (i.e., it is on the unit circle), or else 0. + logabsdet : (...) tensor_like + The natural log of the absolute value of the determinant. + + If the determinant is zero, then `sign` will be 0 and `logabsdet` + will be -inf. In all cases, the determinant is equal to + ``sign * exp(logabsdet)``. """ det_val = det(x) return [ptm.sign(det_val), ptm.log(ptm.abs(det_val))] diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 7d69ac0500..55e7c447e3 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import numpy as np import pytest @@ -22,13 +24,13 @@ def matrix_test(): @pytest.mark.parametrize( "func", - (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det), + (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det), ) def test_lin_alg_no_params(func, matrix_test): x, test_value = matrix_test out = func(x) - out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) + out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out]) def assert_fn(x, y): np.testing.assert_allclose(x, y, rtol=1e-3) From 912eab4e3f63b61ba935dfaa8bd61441490b4d3d Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 16 Nov 2024 23:16:45 +0530 Subject: [PATCH 15/16] forgot to add one small (major) change --- pytensor/tensor/nlinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 95ce771a66..7bc897f413 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -293,7 +293,7 @@ def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariab ``sign * exp(logabsdet)``. 
""" det_val = det(x) - return [ptm.sign(det_val), ptm.log(ptm.abs(det_val))] + return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) class Eig(Op): From 398f4ad9b471050039f9c8324432d6c9913bde4d Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 18 Nov 2024 02:04:33 +0530 Subject: [PATCH 16/16] fixed documentation --- pytensor/tensor/nlinalg.py | 8 ++++---- tests/tensor/rewriting/test_linalg.py | 14 ++++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 7bc897f413..47c6699cca 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -11,6 +11,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm from pytensor.tensor.basic import as_tensor_variable, diagonal @@ -266,7 +267,7 @@ def __str__(self): return "SLogDet" -def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: +def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: """ Compute the sign and (natural) logarithm of the determinant of an array. @@ -279,12 +280,11 @@ def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariab Returns ------- - A namedtuple with the following attributes: + A tuple with the following attributes: sign : (...) tensor_like A number representing the sign of the determinant. For a real matrix, - this is 1, 0, or -1. For a complex matrix, this is a complex number - with absolute value 1 (i.e., it is on the unit circle), or else 0. + this is 1, 0, or -1. logabsdet : (...) tensor_like The natural log of the absolute value of the determinant. 
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 8ac1bb188e..c9b9afff19 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -939,6 +939,7 @@ def test_slogdet_specialization(): log_det_x, log_det_a = pt.log(det_x), np.log(det_a) sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a) exp_det_x = pt.exp(det_x) + # REWRITE TESTS # sign(det(x)) f = function([x], [sign_det_x], mode="FAST_RUN") @@ -952,6 +953,7 @@ def test_slogdet_specialization(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + # log(abs(det(x))) f = function([x], [log_abs_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes @@ -964,6 +966,7 @@ def test_slogdet_specialization(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + # log(det(x)) f = function([x], [log_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes @@ -976,17 +979,20 @@ def test_slogdet_specialization(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) - # more than 1 valid function + + # More than 1 valid function f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 assert not any(isinstance(node.op, Det) for node in nodes) - # other functions (rewrite shouldnt be applied to these) - # only invalid functions + + # Other functions (rewrite shouldnt be applied to these) + # Only invalid functions f = function([x], [exp_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes assert not any(isinstance(node.op, SLogDet) for node in nodes) - # invalid + valid function + + # Invalid + Valid function f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes assert not any(isinstance(node.op, SLogDet) for node in nodes)
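
As a quick illustration of what the final specialization buys at the user level, the following is a minimal sketch (not part of the patches themselves) that mirrors the tests above; it assumes this branch is installed and uses only API already exercised in test_linalg.py. The 5x5 test matrix and default tolerances are arbitrary choices for the example.

import numpy as np

import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.nlinalg import SLogDet

# Build sign(det(x)) and log(abs(det(x))) from an ordinary det graph.
x = pt.dmatrix("x")
det_x = pt.linalg.det(x)
sign_det_x = pt.sign(det_x)
log_abs_det_x = pt.log(pt.abs(det_x))

# With the specialization rewrite registered, FAST_RUN compilation is expected
# to replace the det/sign/log/abs pattern with a single SLogDet node.
f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN")
assert any(isinstance(node.op, SLogDet) for node in f.maker.fgraph.apply_nodes)

# Numerical sanity check against NumPy's slogdet.
a = np.random.rand(5, 5)
expected_sign, expected_logabsdet = np.linalg.slogdet(a)
rewritten_sign, rewritten_logabsdet = f(a)
np.testing.assert_allclose(rewritten_sign, expected_sign)
np.testing.assert_allclose(rewritten_logabsdet, expected_logabsdet)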