2
2
from collections .abc import Callable
3
3
from typing import cast
4
4
5
+ import numpy as np
6
+
5
7
from pytensor import Variable
6
8
from pytensor import tensor as pt
7
9
from pytensor .graph import Apply , FunctionGraph
@@ -967,23 +969,24 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
967
969
return [eye_input * (non_eye_input ** 0.5 )]
968
970
969
971
970
- # SLogDet Rewrites
971
- def check_log_abs_det (fgraph , client ):
972
+ def _check_log_abs_det (fgraph , client ):
972
973
# First, we find abs
973
974
if not (isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Abs )):
974
975
return False
975
976
976
977
# Check whether log is a client of abs
977
978
for client_2 in fgraph .clients [client .outputs [0 ]]:
978
979
if not (
979
- isinstance (client_2 .op , Elemwise ) and isinstance (client_2 .op .scalar_op , Log )
980
+ isinstance (client_2 [0 ].op , Elemwise )
981
+ and isinstance (client_2 [0 ].op .scalar_op , Log )
980
982
):
981
983
return False
982
984
983
985
return True
984
986
985
987
986
- @node_rewriter (tracks = [det ])
988
+ @register_specialize
989
+ @node_rewriter ([det ])
987
990
def slogdet_specialization (fgraph , node ):
988
991
replacements = {}
989
992
for client in fgraph .clients [node .outputs [0 ]]:
@@ -996,19 +999,22 @@ def slogdet_specialization(fgraph, node):
996
999
replacements [client [0 ].outputs [0 ]] = sign_det_x
997
1000
998
1001
# Check for log(abs(det))
999
- elif check_log_abs_det (fgraph , client [0 ]):
1002
+ elif _check_log_abs_det (fgraph , client [0 ]):
1000
1003
x = node .inputs [0 ]
1001
1004
sign_det_x , slog_det_x = SLogDet ()(x )
1002
1005
replacements [fgraph .clients [client [0 ].outputs [0 ]][0 ][0 ].outputs [0 ]] = (
1003
1006
slog_det_x
1004
1007
)
1005
1008
1006
1009
# Check for log(det)
1007
- # elif isinstance(client[0].op, Elemwise) and isinstance(
1008
- # client[0].op.scalar_op, Log
1009
- # ):
1010
- # pass
1011
- # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1010
+ elif isinstance (client [0 ].op , Elemwise ) and isinstance (
1011
+ client [0 ].op .scalar_op , Log
1012
+ ):
1013
+ x = node .inputs [0 ]
1014
+ sign_det_x , slog_det_x = SLogDet ()(x )
1015
+ replacements [client [0 ].outputs [0 ]] = pt .where (
1016
+ pt .eq (sign_det_x , - 1 ), np .nan , slog_det_x
1017
+ )
1012
1018
1013
1019
# Det is used directly for something else, don't rewrite to avoid computing two dets
1014
1020
else :
0 commit comments