Skip to content

Commit 8f6badf

Browse files
committed
splitting rewrite into 2 stages done
1 parent 6927516 commit 8f6badf

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -972,36 +972,39 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
972972
@register_specialize
973973
@node_rewriter([det])
974974
def slogdet_specialization(fgraph, node):
975-
replacements = {}
975+
dummy_replacements = {}
976976
for client, _ in fgraph.clients[node.outputs[0]]:
977977
# Check for sign(det)
978978
if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign):
979-
x = node.inputs[0]
980-
sign_det_x, slog_det_x = SLogDet()(x)
981-
replacements[client.outputs[0]] = sign_det_x
979+
dummy_replacements[client.outputs[0]] = "sign"
982980

983981
# Check for log(abs(det))
984982
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs):
985983
for client_2, _ in fgraph.clients[client.outputs[0]]:
986984
if isinstance(client_2.op, Elemwise) and isinstance(
987985
client_2.op.scalar_op, Log
988986
):
989-
x = node.inputs[0]
990-
sign_det_x, slog_det_x = SLogDet()(x)
991-
replacements[fgraph.clients[client.outputs[0]][0][0].outputs[0]] = (
992-
slog_det_x
993-
)
987+
dummy_replacements[
988+
fgraph.clients[client.outputs[0]][0][0].outputs[0]
989+
] = "log_abs_det"
994990

995991
# Check for log(det)
996992
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log):
997-
x = node.inputs[0]
998-
sign_det_x, slog_det_x = SLogDet()(x)
999-
replacements[client.outputs[0]] = pt.where(
1000-
pt.eq(sign_det_x, -1), np.nan, slog_det_x
1001-
)
993+
dummy_replacements[client.outputs[0]] = "log_det"
1002994

1003995
# Det is used directly for something else, don't rewrite to avoid computing two dets
1004996
else:
1005997
return None
1006998

999+
[x] = node.inputs
1000+
sign_det_x, log_abs_det_x = SLogDet()(x)
1001+
log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
1002+
slogdet_specialization_map = {
1003+
"sign": sign_det_x,
1004+
"log_abs_det": log_abs_det_x,
1005+
"log_det": log_det_x,
1006+
}
1007+
replacements = {
1008+
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
1009+
}
10071010
return replacements or None

0 commit comments

Comments
 (0)