Skip to content

Commit 6b46ffd

Browse files
committed
added specialised rewrite for slogdet
1 parent 7c46f41 commit 6b46ffd

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
copy_stack_trace,
1010
node_rewriter,
1111
)
12-
from pytensor.scalar.basic import Mul
12+
from pytensor.scalar.basic import Abs, Log, Mul, Sign
1313
from pytensor.tensor.basic import (
1414
AllocDiag,
1515
ExtractDiag,
@@ -28,6 +28,7 @@
2828
KroneckerProduct,
2929
MatrixInverse,
3030
MatrixPinv,
31+
SLogDet,
3132
det,
3233
inv,
3334
kron,
@@ -964,3 +965,58 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
964965
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
965966

966967
return [eye_input * (non_eye_input**0.5)]
968+
969+
970+
# SLogDet Rewrites
971+
def check_sign_det(node):
972+
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sign)):
973+
return False
974+
975+
return True
976+
977+
978+
def check_log_abs_det(node):
979+
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)):
980+
return False
981+
982+
potential_abs = node.inputs[0].owner
983+
if not (
984+
isinstance(potential_abs.op, Elemwise)
985+
and isinstance(potential_abs.op.scalar_op, Abs)
986+
):
987+
return False
988+
989+
return True
990+
991+
992+
def check_log_det(node):
993+
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)):
994+
return False
995+
996+
return True
997+
998+
999+
@node_rewriter(tracks=[det])
1000+
def slogdet_specialization(fgraph, node):
1001+
x = node.inputs[0]
1002+
sign_det_x, slog_det_x = SLogDet()(x)
1003+
replacements = {}
1004+
for client in list(fgraph.clients.keys()):
1005+
# Check for sign(det)
1006+
if check_sign_det(client[0].owner):
1007+
replacements[client[0].owner.outputs[0]] = sign_det_x
1008+
1009+
# Check for log(abs(det))
1010+
elif check_log_abs_det(client[0].owner):
1011+
replacements[client[0].owner.outputs[0]] = slog_det_x
1012+
1013+
# Check for log(det)
1014+
elif check_log_det(client[0].owner):
1015+
pass
1016+
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1017+
1018+
# Det is used directly for something else, don't rewrite to avoid computing two dets
1019+
else:
1020+
return None
1021+
1022+
return replacements or None

0 commit comments

Comments
 (0)