@@ -972,36 +972,39 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
972
972
@register_specialize
973
973
@node_rewriter ([det ])
974
974
def slogdet_specialization (fgraph , node ):
975
- replacements = {}
975
+ dummy_replacements = {}
976
976
for client , _ in fgraph .clients [node .outputs [0 ]]:
977
977
# Check for sign(det)
978
978
if isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Sign ):
979
- x = node .inputs [0 ]
980
- sign_det_x , slog_det_x = SLogDet ()(x )
981
- replacements [client .outputs [0 ]] = sign_det_x
979
+ dummy_replacements [client .outputs [0 ]] = "sign"
982
980
983
981
# Check for log(abs(det))
984
982
elif isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Abs ):
985
983
for client_2 , _ in fgraph .clients [client .outputs [0 ]]:
986
984
if isinstance (client_2 .op , Elemwise ) and isinstance (
987
985
client_2 .op .scalar_op , Log
988
986
):
989
- x = node .inputs [0 ]
990
- sign_det_x , slog_det_x = SLogDet ()(x )
991
- replacements [fgraph .clients [client .outputs [0 ]][0 ][0 ].outputs [0 ]] = (
992
- slog_det_x
993
- )
987
+ dummy_replacements [
988
+ fgraph .clients [client .outputs [0 ]][0 ][0 ].outputs [0 ]
989
+ ] = "log_abs_det"
994
990
995
991
# Check for log(det)
996
992
elif isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Log ):
997
- x = node .inputs [0 ]
998
- sign_det_x , slog_det_x = SLogDet ()(x )
999
- replacements [client .outputs [0 ]] = pt .where (
1000
- pt .eq (sign_det_x , - 1 ), np .nan , slog_det_x
1001
- )
993
+ dummy_replacements [client .outputs [0 ]] = "log_det"
1002
994
1003
995
# Det is used directly for something else, don't rewrite to avoid computing two dets
1004
996
else :
1005
997
return None
1006
998
999
+ [x ] = node .inputs
1000
+ sign_det_x , log_abs_det_x = SLogDet ()(x )
1001
+ log_det_x = pt .where (pt .eq (sign_det_x , - 1 ), np .nan , log_abs_det_x )
1002
+ slogdet_specialization_map = {
1003
+ "sign" : sign_det_x ,
1004
+ "log_abs_det" : log_abs_det_x ,
1005
+ "log_det" : log_det_x ,
1006
+ }
1007
+ replacements = {
1008
+ k : slogdet_specialization_map [v ] for k , v in dummy_replacements .items ()
1009
+ }
1007
1010
return replacements or None
0 commit comments