@@ -985,26 +985,30 @@ def check_log_abs_det(fgraph, client):
985
985
986
986
@node_rewriter (tracks = [det ])
987
987
def slogdet_specialization (fgraph , node ):
988
- x = node .inputs [0 ]
989
- sign_det_x , slog_det_x = SLogDet ()(x )
990
988
replacements = {}
991
989
for client in fgraph .clients [node .outputs [0 ]]:
992
990
# Check for sign(det)
993
991
if isinstance (client [0 ].op , Elemwise ) and isinstance (
994
992
client [0 ].op .scalar_op , Sign
995
993
):
996
- replacements [client [0 ].owner .outputs [0 ]] = sign_det_x
994
+ x = node .inputs [0 ]
995
+ sign_det_x , slog_det_x = SLogDet ()(x )
996
+ replacements [client [0 ].outputs [0 ]] = sign_det_x
997
997
998
998
# Check for log(abs(det))
999
999
elif check_log_abs_det (fgraph , client [0 ]):
1000
- replacements [client [0 ].owner .outputs [0 ]] = slog_det_x
1000
+ x = node .inputs [0 ]
1001
+ sign_det_x , slog_det_x = SLogDet ()(x )
1002
+ replacements [fgraph .clients [client [0 ].outputs [0 ]][0 ][0 ].outputs [0 ]] = (
1003
+ slog_det_x
1004
+ )
1001
1005
1002
1006
# Check for log(det)
1003
- elif isinstance (client [0 ].op , Elemwise ) and isinstance (
1004
- client [0 ].op .scalar_op , Log
1005
- ):
1006
- pass
1007
- # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1007
+ # elif isinstance(client[0].op, Elemwise) and isinstance(
1008
+ # client[0].op.scalar_op, Log
1009
+ # ):
1010
+ # pass
1011
+ # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1008
1012
1009
1013
# Det is used directly for something else, don't rewrite to avoid computing two dets
1010
1014
else :
0 commit comments