Skip to content

Commit 7c46f41

Browse files
committed
removed rewrites for slogdet and added the same for det which will be used later
1 parent c930399 commit 7c46f41

File tree

2 files changed

+35
-42
lines changed

2 files changed

+35
-42
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -853,39 +853,38 @@ def rewrite_diag_kronecker(fgraph, node):
853853
return [outer_prod_as_vector]
854854

855855

856-
# @register_canonicalize
857-
# @register_stabilize
858-
# @node_rewriter([slogdet])
859-
# def rewrite_slogdet_kronecker(fgraph, node):
860-
# """
861-
# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
856+
@register_canonicalize
857+
@register_stabilize
858+
@node_rewriter([det])
859+
def rewrite_det_kronecker(fgraph, node):
860+
"""
861+
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those
862862
863-
# Parameters
864-
# ----------
865-
# fgraph: FunctionGraph
866-
# Function graph being optimized
867-
# node: Apply
868-
# Node of the function graph to be optimized
863+
Parameters
864+
----------
865+
fgraph: FunctionGraph
866+
Function graph being optimized
867+
node: Apply
868+
Node of the function graph to be optimized
869869
870-
# Returns
871-
# -------
872-
# list of Variable, optional
873-
# List of optimized variables, or None if no optimization was performed
874-
# """
875-
# # Check for inner kron operation
876-
# potential_kron = node.inputs[0].owner
877-
# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
878-
# return None
870+
Returns
871+
-------
872+
list of Variable, optional
873+
List of optimized variables, or None if no optimization was performed
874+
"""
875+
# Check for inner kron operation
876+
potential_kron = node.inputs[0].owner
877+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
878+
return None
879879

880-
# # Find the matrices
881-
# a, b = potential_kron.inputs
882-
# signs, logdets = zip(*[slogdet(a), slogdet(b)])
883-
# sizes = [a.shape[-1], b.shape[-1]]
884-
# prod_sizes = prod(sizes, no_zeros_in_input=True)
885-
# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
886-
# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
880+
# Find the matrices
881+
a, b = potential_kron.inputs
882+
dets = [det(a), det(b)]
883+
sizes = [a.shape[-1], b.shape[-1]]
884+
prod_sizes = prod(sizes, no_zeros_in_input=True)
885+
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])
887886

888-
# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
887+
return [det_final]
889888

890889

891890
@register_canonicalize

tests/tensor/rewriting/test_linalg.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -776,11 +776,11 @@ def test_diag_kronecker_rewrite():
776776
)
777777

778778

779-
def test_slogdet_kronecker_rewrite():
779+
def test_det_kronecker_rewrite():
780780
a, b = pt.dmatrices("a", "b")
781781
kron_prod = pt.linalg.kron(a, b)
782-
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783-
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
782+
det_output = pt.linalg.det(kron_prod)
783+
f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN")
784784

785785
# Rewrite Test
786786
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -789,17 +789,11 @@ def test_slogdet_kronecker_rewrite():
789789
# Value Test
790790
a_test, b_test = np.random.rand(2, 20, 20)
791791
kron_prod_test = np.kron(a_test, b_test)
792-
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793-
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
792+
det_output_test = np.linalg.det(kron_prod_test)
793+
rewritten_det_val = f_rewritten(kron_prod_test)
794794
assert_allclose(
795-
sign_output_test,
796-
rewritten_sign_val,
797-
atol=1e-3 if config.floatX == "float32" else 1e-8,
798-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
799-
)
800-
assert_allclose(
801-
logdet_output_test,
802-
rewritten_logdet_val,
795+
det_output_test,
796+
rewritten_det_val,
803797
atol=1e-3 if config.floatX == "float32" else 1e-8,
804798
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805799
)

0 commit comments

Comments
 (0)