Skip to content

Commit 816cc77

Browse files
committed
working specialised rewrite + test
1 parent 6767600 commit 816cc77

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from collections.abc import Callable
33
from typing import cast
44

5+
import numpy as np
6+
57
from pytensor import Variable
68
from pytensor import tensor as pt
79
from pytensor.graph import Apply, FunctionGraph
@@ -967,23 +969,24 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
967969
return [eye_input * (non_eye_input**0.5)]
968970

969971

970-
# SLogDet Rewrites
971-
def check_log_abs_det(fgraph, client):
972+
def _check_log_abs_det(fgraph, client):
972973
# First, we find abs
973974
if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)):
974975
return False
975976

976977
# Check whether log is a client of abs
977978
for client_2 in fgraph.clients[client.outputs[0]]:
978979
if not (
979-
isinstance(client_2.op, Elemwise) and isinstance(client_2.op.scalar_op, Log)
980+
isinstance(client_2[0].op, Elemwise)
981+
and isinstance(client_2[0].op.scalar_op, Log)
980982
):
981983
return False
982984

983985
return True
984986

985987

986-
@node_rewriter(tracks=[det])
988+
@register_specialize
989+
@node_rewriter([det])
987990
def slogdet_specialization(fgraph, node):
988991
replacements = {}
989992
for client in fgraph.clients[node.outputs[0]]:
@@ -996,19 +999,22 @@ def slogdet_specialization(fgraph, node):
996999
replacements[client[0].outputs[0]] = sign_det_x
9971000

9981001
# Check for log(abs(det))
999-
elif check_log_abs_det(fgraph, client[0]):
1002+
elif _check_log_abs_det(fgraph, client[0]):
10001003
x = node.inputs[0]
10011004
sign_det_x, slog_det_x = SLogDet()(x)
10021005
replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = (
10031006
slog_det_x
10041007
)
10051008

10061009
# Check for log(det)
1007-
# elif isinstance(client[0].op, Elemwise) and isinstance(
1008-
# client[0].op.scalar_op, Log
1009-
# ):
1010-
# pass
1011-
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1010+
elif isinstance(client[0].op, Elemwise) and isinstance(
1011+
client[0].op.scalar_op, Log
1012+
):
1013+
x = node.inputs[0]
1014+
sign_det_x, slog_det_x = SLogDet()(x)
1015+
replacements[client[0].outputs[0]] = pt.where(
1016+
pt.eq(sign_det_x, -1), np.nan, slog_det_x
1017+
)
10121018

10131019
# Det is used directly for something else, don't rewrite to avoid computing two dets
10141020
else:

tests/tensor/rewriting/test_linalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
KroneckerProduct,
2222
MatrixInverse,
2323
MatrixPinv,
24+
SLogDet,
2425
matrix_inverse,
2526
svd,
2627
)
@@ -900,3 +901,23 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
900901
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
901902
nodes = f_rewritten.maker.fgraph.apply_nodes
902903
assert any(isinstance(node.op, Cholesky) for node in nodes)
904+
905+
906+
def test_slogdet_specialisation():
907+
x = pt.dmatrix("x")
908+
det_x = pt.linalg.det(x)
909+
log_abs_det_x = pt.log(pt.abs(det_x))
910+
sign_det_x = pt.sign(det_x)
911+
exp_det_x = pt.exp(det_x)
912+
# sign(det(x))
913+
f = function([x], [sign_det_x], mode="FAST_RUN")
914+
nodes = f.maker.fgraph.apply_nodes
915+
assert any(isinstance(node.op, SLogDet) for node in nodes)
916+
# log(abs(det(x)))
917+
f = function([x], [log_abs_det_x], mode="FAST_RUN")
918+
nodes = f.maker.fgraph.apply_nodes
919+
assert any(isinstance(node.op, SLogDet) for node in nodes)
920+
# other functions (rewrite shouldnt be applied to these)
921+
f = function([x], [exp_det_x], mode="FAST_RUN")
922+
nodes = f.maker.fgraph.apply_nodes
923+
assert not any(isinstance(node.op, SLogDet) for node in nodes)

0 commit comments

Comments
 (0)