|
49 | 49 | from pytensor.graph.op import Op
|
50 | 50 | from pytensor.graph.replace import clone_replace
|
51 | 51 | from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
|
52 |
| -from pytensor.scalar import Add, Exp, Log, Mul, Pow, Sqr, Sqrt |
| 52 | +from pytensor.scalar import Abs, Add, Exp, Log, Mul, Pow, Sqr, Sqrt |
53 | 53 | from pytensor.scan.op import Scan
|
54 | 54 | from pytensor.tensor.exceptions import NotScalarConstantError
|
55 | 55 | from pytensor.tensor.math import (
|
| 56 | + abs, |
56 | 57 | add,
|
57 | 58 | exp,
|
58 | 59 | log,
|
@@ -336,7 +337,7 @@ def apply(self, fgraph: FunctionGraph):
|
336 | 337 | class MeasurableTransform(MeasurableElemwise):
|
337 | 338 | """A placeholder used to specify a log-likelihood for a transformed measurable variable"""
|
338 | 339 |
|
339 |
| - valid_scalar_types = (Exp, Log, Add, Mul, Pow) |
| 340 | + valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs) |
340 | 341 |
|
341 | 342 | # Cannot use `transform` as name because it would clash with the property added by
|
342 | 343 | # the `TransformValuesRewrite`
|
@@ -498,7 +499,7 @@ def measurable_sub_to_neg(fgraph, node):
|
498 | 499 | return [at.add(minuend, at.neg(subtrahend))]
|
499 | 500 |
|
500 | 501 |
|
501 |
| -@node_rewriter([exp, log, add, mul, pow]) |
| 502 | +@node_rewriter([exp, log, add, mul, pow, abs]) |
502 | 503 | def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
|
503 | 504 | """Find measurable transformations from Elemwise operators."""
|
504 | 505 |
|
@@ -558,6 +559,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
|
558 | 559 | transform = ExpTransform()
|
559 | 560 | elif isinstance(scalar_op, Log):
|
560 | 561 | transform = LogTransform()
|
| 562 | + elif isinstance(scalar_op, Abs): |
| 563 | + transform = AbsTransform() |
561 | 564 | elif isinstance(scalar_op, Pow):
|
562 | 565 | # We only allow for the base to be measurable
|
563 | 566 | if measurable_input_idx != 0:
|
@@ -701,6 +704,20 @@ def log_jac_det(self, value, *inputs):
|
701 | 704 | return -at.log(value)
|
702 | 705 |
|
703 | 706 |
|
| 707 | +class AbsTransform(RVTransform): |
| 708 | + name = "abs" |
| 709 | + |
| 710 | + def forward(self, value, *inputs): |
| 711 | + return at.abs(value) |
| 712 | + |
| 713 | + def backward(self, value, *inputs): |
| 714 | + value = at.switch(value >= 0, value, np.nan) |
| 715 | + return -value, value |
| 716 | + |
| 717 | + def log_jac_det(self, value, *inputs): |
| 718 | + return at.switch(value >= 0, 0, np.nan) |
| 719 | + |
| 720 | + |
704 | 721 | class PowerTransform(RVTransform):
|
705 | 722 | name = "power"
|
706 | 723 |
|
|
0 commit comments