Skip to content

Commit 6927516

Browse files
committed
added all tests
1 parent 03215e8 commit 6927516

File tree

3 files changed

+92
-40
lines changed

3 files changed

+92
-40
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ def __str__(self):
266266
return "SLogDet"
267267

268268

269-
# slogdet = Blockwise(SLogDet())
270269
def slogdet(x):
271270
det_val = det(x)
272271
return ptm.sign(det_val), ptm.log(ptm.abs(det_val))

pytensor/tensor/rewriting/linalg.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -969,50 +969,34 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
969969
return [eye_input * (non_eye_input**0.5)]
970970

971971

972-
def _check_log_abs_det(fgraph, client):
973-
# First, we find abs
974-
if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)):
975-
return False
976-
977-
# Check whether log is a client of abs
978-
for client_2 in fgraph.clients[client.outputs[0]]:
979-
if not (
980-
isinstance(client_2[0].op, Elemwise)
981-
and isinstance(client_2[0].op.scalar_op, Log)
982-
):
983-
return False
984-
985-
return True
986-
987-
988972
@register_specialize
989973
@node_rewriter([det])
990974
def slogdet_specialization(fgraph, node):
991975
replacements = {}
992-
for client in fgraph.clients[node.outputs[0]]:
976+
for client, _ in fgraph.clients[node.outputs[0]]:
993977
# Check for sign(det)
994-
if isinstance(client[0].op, Elemwise) and isinstance(
995-
client[0].op.scalar_op, Sign
996-
):
978+
if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign):
997979
x = node.inputs[0]
998980
sign_det_x, slog_det_x = SLogDet()(x)
999-
replacements[client[0].outputs[0]] = sign_det_x
981+
replacements[client.outputs[0]] = sign_det_x
1000982

1001983
# Check for log(abs(det))
1002-
elif _check_log_abs_det(fgraph, client[0]):
1003-
x = node.inputs[0]
1004-
sign_det_x, slog_det_x = SLogDet()(x)
1005-
replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = (
1006-
slog_det_x
1007-
)
984+
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs):
985+
for client_2, _ in fgraph.clients[client.outputs[0]]:
986+
if isinstance(client_2.op, Elemwise) and isinstance(
987+
client_2.op.scalar_op, Log
988+
):
989+
x = node.inputs[0]
990+
sign_det_x, slog_det_x = SLogDet()(x)
991+
replacements[fgraph.clients[client.outputs[0]][0][0].outputs[0]] = (
992+
slog_det_x
993+
)
1008994

1009995
# Check for log(det)
1010-
elif isinstance(client[0].op, Elemwise) and isinstance(
1011-
client[0].op.scalar_op, Log
1012-
):
996+
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log):
1013997
x = node.inputs[0]
1014998
sign_det_x, slog_det_x = SLogDet()(x)
1015-
replacements[client[0].outputs[0]] = pt.where(
999+
replacements[client.outputs[0]] = pt.where(
10161000
pt.eq(sign_det_x, -1), np.nan, slog_det_x
10171001
)
10181002

tests/tensor/rewriting/test_linalg.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ def test_det_kronecker_rewrite():
781781
a, b = pt.dmatrices("a", "b")
782782
kron_prod = pt.linalg.kron(a, b)
783783
det_output = pt.linalg.det(kron_prod)
784-
f_rewritten = function([kron_prod], [det_output], mode="FAST_RUN")
784+
f_rewritten = function([a, b], [det_output], mode="FAST_RUN")
785785

786786
# Rewrite Test
787787
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -791,7 +791,7 @@ def test_det_kronecker_rewrite():
791791
a_test, b_test = np.random.rand(2, 20, 20)
792792
kron_prod_test = np.kron(a_test, b_test)
793793
det_output_test = np.linalg.det(kron_prod_test)
794-
rewritten_det_val = f_rewritten(kron_prod_test)
794+
rewritten_det_val = f_rewritten(a_test, b_test)
795795
assert_allclose(
796796
det_output_test,
797797
rewritten_det_val,
@@ -800,6 +800,35 @@ def test_det_kronecker_rewrite():
800800
)
801801

802802

803+
def test_slogdet_kronecker_rewrite():
804+
a, b = pt.dmatrices("a", "b")
805+
kron_prod = pt.linalg.kron(a, b)
806+
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
807+
f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN")
808+
809+
# Rewrite Test
810+
nodes = f_rewritten.maker.fgraph.apply_nodes
811+
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)
812+
813+
# Value Test
814+
a_test, b_test = np.random.rand(2, 20, 20)
815+
kron_prod_test = np.kron(a_test, b_test)
816+
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
817+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test)
818+
assert_allclose(
819+
sign_output_test,
820+
rewritten_sign_val,
821+
atol=1e-3 if config.floatX == "float32" else 1e-8,
822+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
823+
)
824+
assert_allclose(
825+
logdet_output_test,
826+
rewritten_logdet_val,
827+
atol=1e-3 if config.floatX == "float32" else 1e-8,
828+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
829+
)
830+
831+
803832
def test_cholesky_eye_rewrite():
804833
x = pt.eye(10)
805834
L = pt.linalg.cholesky(x)
@@ -904,20 +933,60 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
904933

905934

906935
def test_slogdet_specialisation():
907-
x = pt.dmatrix("x")
908-
det_x = pt.linalg.det(x)
909-
log_abs_det_x = pt.log(pt.abs(det_x))
910-
sign_det_x = pt.sign(det_x)
936+
x, a = pt.dmatrix("x"), np.random.rand(20, 20)
937+
det_x, det_a = pt.linalg.det(x), np.linalg.det(a)
938+
log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a))
939+
log_det_x, log_det_a = pt.log(det_x), np.log(det_a)
940+
sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a)
911941
exp_det_x = pt.exp(det_x)
942+
# REWRITE TESTS
912943
# sign(det(x))
913944
f = function([x], [sign_det_x], mode="FAST_RUN")
914945
nodes = f.maker.fgraph.apply_nodes
915-
assert any(isinstance(node.op, SLogDet) for node in nodes)
946+
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
947+
assert not any(isinstance(node.op, Det) for node in nodes)
948+
rw_sign_det_a = f(a)
949+
assert_allclose(
950+
sign_det_a,
951+
rw_sign_det_a,
952+
atol=1e-3 if config.floatX == "float32" else 1e-8,
953+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
954+
)
916955
# log(abs(det(x)))
917956
f = function([x], [log_abs_det_x], mode="FAST_RUN")
918957
nodes = f.maker.fgraph.apply_nodes
919-
assert any(isinstance(node.op, SLogDet) for node in nodes)
958+
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
959+
assert not any(isinstance(node.op, Det) for node in nodes)
960+
rw_log_abs_det_a = f(a)
961+
assert_allclose(
962+
log_abs_det_a,
963+
rw_log_abs_det_a,
964+
atol=1e-3 if config.floatX == "float32" else 1e-8,
965+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
966+
)
967+
# log(det(x))
968+
f = function([x], [log_det_x], mode="FAST_RUN")
969+
nodes = f.maker.fgraph.apply_nodes
970+
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
971+
assert not any(isinstance(node.op, Det) for node in nodes)
972+
rw_log_det_a = f(a)
973+
assert_allclose(
974+
log_det_a,
975+
rw_log_det_a,
976+
atol=1e-3 if config.floatX == "float32" else 1e-8,
977+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
978+
)
979+
# more than 1 valid function
980+
f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN")
981+
nodes = f.maker.fgraph.apply_nodes
982+
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
983+
assert not any(isinstance(node.op, Det) for node in nodes)
920984
# other functions (rewrite shouldnt be applied to these)
985+
# only invalid functions
921986
f = function([x], [exp_det_x], mode="FAST_RUN")
922987
nodes = f.maker.fgraph.apply_nodes
923988
assert not any(isinstance(node.op, SLogDet) for node in nodes)
989+
# invalid + valid function
990+
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
991+
nodes = f.maker.fgraph.apply_nodes
992+
assert not any(isinstance(node.op, SLogDet) for node in nodes)

0 commit comments

Comments
 (0)