Skip to content

Commit d2b7336

Browse files
committed
updated checks for specialised rewrite
1 parent 7d1eadd commit d2b7336

File tree

1 file changed

+17
-26
lines changed

1 file changed

+17
-26
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -968,30 +968,17 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
968968

969969

970970
# SLogDet Rewrites
971-
def check_sign_det(node):
972-
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sign)):
971+
def check_log_abs_det(fgraph, client):
972+
# First, we find abs
973+
if not (isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs)):
973974
return False
974975

975-
return True
976-
977-
978-
def check_log_abs_det(node):
979-
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)):
980-
return False
981-
982-
potential_abs = node.inputs[0].owner
983-
if not (
984-
isinstance(potential_abs.op, Elemwise)
985-
and isinstance(potential_abs.op.scalar_op, Abs)
986-
):
987-
return False
988-
989-
return True
990-
991-
992-
def check_log_det(node):
993-
if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)):
994-
return False
976+
# Check whether log is a client of abs
977+
for client_2 in fgraph.clients[client.outputs[0]]:
978+
if not (
979+
isinstance(client_2.op, Elemwise) and isinstance(client_2.op.scalar_op, Log)
980+
):
981+
return False
995982

996983
return True
997984

@@ -1001,17 +988,21 @@ def slogdet_specialization(fgraph, node):
1001988
x = node.inputs[0]
1002989
sign_det_x, slog_det_x = SLogDet()(x)
1003990
replacements = {}
1004-
for client in list(fgraph.clients.keys()):
991+
for client in fgraph.clients[node.outputs[0]]:
1005992
# Check for sign(det)
1006-
if check_sign_det(client[0].owner):
993+
if isinstance(client[0].op, Elemwise) and isinstance(
994+
client[0].op.scalar_op, Sign
995+
):
1007996
replacements[client[0].owner.outputs[0]] = sign_det_x
1008997

1009998
# Check for log(abs(det))
1010-
elif check_log_abs_det(client[0].owner):
999+
elif check_log_abs_det(fgraph, client[0]):
10111000
replacements[client[0].owner.outputs[0]] = slog_det_x
10121001

10131002
# Check for log(det)
1014-
elif check_log_det(client[0].owner):
1003+
elif isinstance(client[0].op, Elemwise) and isinstance(
1004+
client[0].op.scalar_op, Log
1005+
):
10151006
pass
10161007
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
10171008

0 commit comments

Comments
 (0)