Skip to content

Commit 6767600

Browse files
committed
updated specialisation rewrite
1 parent fe83ebc commit 6767600

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -985,26 +985,30 @@ def check_log_abs_det(fgraph, client):
985985

986986
@node_rewriter(tracks=[det])
987987
def slogdet_specialization(fgraph, node):
988-
x = node.inputs[0]
989-
sign_det_x, slog_det_x = SLogDet()(x)
990988
replacements = {}
991989
for client in fgraph.clients[node.outputs[0]]:
992990
# Check for sign(det)
993991
if isinstance(client[0].op, Elemwise) and isinstance(
994992
client[0].op.scalar_op, Sign
995993
):
996-
replacements[client[0].owner.outputs[0]] = sign_det_x
994+
x = node.inputs[0]
995+
sign_det_x, slog_det_x = SLogDet()(x)
996+
replacements[client[0].outputs[0]] = sign_det_x
997997

998998
# Check for log(abs(det))
999999
elif check_log_abs_det(fgraph, client[0]):
1000-
replacements[client[0].owner.outputs[0]] = slog_det_x
1000+
x = node.inputs[0]
1001+
sign_det_x, slog_det_x = SLogDet()(x)
1002+
replacements[fgraph.clients[client[0].outputs[0]][0][0].outputs[0]] = (
1003+
slog_det_x
1004+
)
10011005

10021006
# Check for log(det)
1003-
elif isinstance(client[0].op, Elemwise) and isinstance(
1004-
client[0].op.scalar_op, Log
1005-
):
1006-
pass
1007-
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1007+
# elif isinstance(client[0].op, Elemwise) and isinstance(
1008+
# client[0].op.scalar_op, Log
1009+
# ):
1010+
# pass
1011+
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
10081012

10091013
# Det is used directly for something else, don't rewrite to avoid computing two dets
10101014
else:

0 commit comments

Comments
 (0)