@@ -968,30 +968,17 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
968
968
969
969
970
970
# SLogDet Rewrites
971
- def check_sign_det (node ):
972
- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Sign )):
971
+ def check_log_abs_det (fgraph , client ):
972
+ # First, we find abs
973
+ if not (isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Abs )):
973
974
return False
974
975
975
- return True
976
-
977
-
978
- def check_log_abs_det (node ):
979
- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Log )):
980
- return False
981
-
982
- potential_abs = node .inputs [0 ].owner
983
- if not (
984
- isinstance (potential_abs .op , Elemwise )
985
- and isinstance (potential_abs .op .scalar_op , Abs )
986
- ):
987
- return False
988
-
989
- return True
990
-
991
-
992
- def check_log_det (node ):
993
- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Log )):
994
- return False
976
+ # Check whether log is a client of abs
977
+ for client_2 in fgraph .clients [client .outputs [0 ]]:
978
+ if not (
979
+ isinstance (client_2 .op , Elemwise ) and isinstance (client_2 .op .scalar_op , Log )
980
+ ):
981
+ return False
995
982
996
983
return True
997
984
@@ -1001,17 +988,21 @@ def slogdet_specialization(fgraph, node):
1001
988
x = node .inputs [0 ]
1002
989
sign_det_x , slog_det_x = SLogDet ()(x )
1003
990
replacements = {}
1004
- for client in list ( fgraph .clients . keys ()) :
991
+ for client in fgraph .clients [ node . outputs [ 0 ]] :
1005
992
# Check for sign(det)
1006
- if check_sign_det (client [0 ].owner ):
993
+ if isinstance (client [0 ].op , Elemwise ) and isinstance (
994
+ client [0 ].op .scalar_op , Sign
995
+ ):
1007
996
replacements [client [0 ].owner .outputs [0 ]] = sign_det_x
1008
997
1009
998
# Check for log(abs(det))
1010
- elif check_log_abs_det (client [0 ]. owner ):
999
+ elif check_log_abs_det (fgraph , client [0 ]):
1011
1000
replacements [client [0 ].owner .outputs [0 ]] = slog_det_x
1012
1001
1013
1002
# Check for log(det)
1014
- elif check_log_det (client [0 ].owner ):
1003
+ elif isinstance (client [0 ].op , Elemwise ) and isinstance (
1004
+ client [0 ].op .scalar_op , Log
1005
+ ):
1015
1006
pass
1016
1007
# replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1017
1008
0 commit comments