Skip to content

Commit 6826803

Browse files
committed
removed rewrites for slogdet and added the same for det which will be used later
1 parent 2859dbb commit 6826803

File tree

2 files changed

+35
-42
lines changed

2 files changed

+35
-42
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -857,39 +857,38 @@ def rewrite_diag_kronecker(fgraph, node):
857857
return [outer_prod_as_vector]
858858

859859

860-
# @register_canonicalize
861-
# @register_stabilize
862-
# @node_rewriter([slogdet])
863-
# def rewrite_slogdet_kronecker(fgraph, node):
864-
# """
865-
# This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
860+
@register_canonicalize
861+
@register_stabilize
862+
@node_rewriter([det])
863+
def rewrite_det_kronecker(fgraph, node):
864+
"""
865+
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those
866866
867-
# Parameters
868-
# ----------
869-
# fgraph: FunctionGraph
870-
# Function graph being optimized
871-
# node: Apply
872-
# Node of the function graph to be optimized
867+
Parameters
868+
----------
869+
fgraph: FunctionGraph
870+
Function graph being optimized
871+
node: Apply
872+
Node of the function graph to be optimized
873873
874-
# Returns
875-
# -------
876-
# list of Variable, optional
877-
# List of optimized variables, or None if no optimization was performed
878-
# """
879-
# # Check for inner kron operation
880-
# potential_kron = node.inputs[0].owner
881-
# if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
882-
# return None
874+
Returns
875+
-------
876+
list of Variable, optional
877+
List of optimized variables, or None if no optimization was performed
878+
"""
879+
# Check for inner kron operation
880+
potential_kron = node.inputs[0].owner
881+
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
882+
return None
883883

884-
# # Find the matrices
885-
# a, b = potential_kron.inputs
886-
# signs, logdets = zip(*[slogdet(a), slogdet(b)])
887-
# sizes = [a.shape[-1], b.shape[-1]]
888-
# prod_sizes = prod(sizes, no_zeros_in_input=True)
889-
# signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
890-
# logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
884+
# Find the matrices
885+
a, b = potential_kron.inputs
886+
dets = [det(a), det(b)]
887+
sizes = [a.shape[-1], b.shape[-1]]
888+
prod_sizes = prod(sizes, no_zeros_in_input=True)
889+
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])
891890

892-
# return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
891+
return [det_final]
893892

894893

895894
@register_canonicalize

tests/tensor/rewriting/test_linalg.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -776,11 +776,11 @@ def test_diag_kronecker_rewrite():
776776
)
777777

778778

779-
def test_slogdet_kronecker_rewrite():
779+
def test_det_kronecker_rewrite():
780780
a, b = pt.dmatrices("a", "b")
781781
kron_prod = pt.linalg.kron(a, b)
782-
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783-
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
782+
det_output = pt.linalg.det(kron_prod)
783+
f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN")
784784

785785
# Rewrite Test
786786
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -789,17 +789,11 @@ def test_slogdet_kronecker_rewrite():
789789
# Value Test
790790
a_test, b_test = np.random.rand(2, 20, 20)
791791
kron_prod_test = np.kron(a_test, b_test)
792-
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793-
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
792+
det_output_test = np.linalg.det(kron_prod_test)
793+
rewritten_det_val = f_rewritten(kron_prod_test)
794794
assert_allclose(
795-
sign_output_test,
796-
rewritten_sign_val,
797-
atol=1e-3 if config.floatX == "float32" else 1e-8,
798-
rtol=1e-3 if config.floatX == "float32" else 1e-8,
799-
)
800-
assert_allclose(
801-
logdet_output_test,
802-
rewritten_logdet_val,
795+
det_output_test,
796+
rewritten_det_val,
803797
atol=1e-3 if config.floatX == "float32" else 1e-8,
804798
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805799
)

0 commit comments

Comments
 (0)