|
9 | 9 | copy_stack_trace,
|
10 | 10 | node_rewriter,
|
11 | 11 | )
|
12 |
| -from pytensor.scalar.basic import Mul |
| 12 | +from pytensor.scalar.basic import Abs, Log, Mul, Sign |
13 | 13 | from pytensor.tensor.basic import (
|
14 | 14 | AllocDiag,
|
15 | 15 | ExtractDiag,
|
|
28 | 28 | KroneckerProduct,
|
29 | 29 | MatrixInverse,
|
30 | 30 | MatrixPinv,
|
| 31 | + SLogDet, |
31 | 32 | det,
|
32 | 33 | inv,
|
33 | 34 | kron,
|
@@ -964,3 +965,58 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
|
964 | 965 | non_eye_input = pt.shape_padaxis(non_eye_input, -2)
|
965 | 966 |
|
966 | 967 | return [eye_input * (non_eye_input**0.5)]
|
| 968 | + |
| 969 | + |
| 970 | +# SLogDet Rewrites |
| 971 | +def check_sign_det(node): |
| 972 | + if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Sign)): |
| 973 | + return False |
| 974 | + |
| 975 | + return True |
| 976 | + |
| 977 | + |
| 978 | +def check_log_abs_det(node): |
| 979 | + if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)): |
| 980 | + return False |
| 981 | + |
| 982 | + potential_abs = node.inputs[0].owner |
| 983 | + if not ( |
| 984 | + isinstance(potential_abs.op, Elemwise) |
| 985 | + and isinstance(potential_abs.op.scalar_op, Abs) |
| 986 | + ): |
| 987 | + return False |
| 988 | + |
| 989 | + return True |
| 990 | + |
| 991 | + |
| 992 | +def check_log_det(node): |
| 993 | + if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Log)): |
| 994 | + return False |
| 995 | + |
| 996 | + return True |
| 997 | + |
| 998 | + |
| 999 | +@node_rewriter(tracks=[det]) |
| 1000 | +def slogdet_specialization(fgraph, node): |
| 1001 | + x = node.inputs[0] |
| 1002 | + sign_det_x, slog_det_x = SLogDet()(x) |
| 1003 | + replacements = {} |
| 1004 | + for client in list(fgraph.clients.keys()): |
| 1005 | + # Check for sign(det) |
| 1006 | + if check_sign_det(client[0].owner): |
| 1007 | + replacements[client[0].owner.outputs[0]] = sign_det_x |
| 1008 | + |
| 1009 | + # Check for log(abs(det)) |
| 1010 | + elif check_log_abs_det(client[0].owner): |
| 1011 | + replacements[client[0].owner.outputs[0]] = slog_det_x |
| 1012 | + |
| 1013 | + # Check for log(det) |
| 1014 | + elif check_log_det(client[0].owner): |
| 1015 | + pass |
| 1016 | + # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x) |
| 1017 | + |
| 1018 | + # Det is used directly for something else, don't rewrite to avoid computing two dets |
| 1019 | + else: |
| 1020 | + return None |
| 1021 | + |
| 1022 | + return replacements or None |
0 commit comments