Skip to content

Commit db11a23

Browse files
ricardoV94Luke LB
and
Luke LB
committed
Infer logrob of absolute transform
Co-authored-by: Luke LB <[email protected]>
1 parent 2953c8b commit db11a23

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

pymc/logprob/transforms.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@
4949
from pytensor.graph.op import Op
5050
from pytensor.graph.replace import clone_replace
5151
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
52-
from pytensor.scalar import Add, Exp, Log, Mul, Pow, Sqr, Sqrt
52+
from pytensor.scalar import Abs, Add, Exp, Log, Mul, Pow, Sqr, Sqrt
5353
from pytensor.scan.op import Scan
5454
from pytensor.tensor.exceptions import NotScalarConstantError
5555
from pytensor.tensor.math import (
56+
abs,
5657
add,
5758
exp,
5859
log,
@@ -336,7 +337,7 @@ def apply(self, fgraph: FunctionGraph):
336337
class MeasurableTransform(MeasurableElemwise):
337338
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
338339

339-
valid_scalar_types = (Exp, Log, Add, Mul, Pow)
340+
valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs)
340341

341342
# Cannot use `transform` as name because it would clash with the property added by
342343
# the `TransformValuesRewrite`
@@ -498,7 +499,7 @@ def measurable_sub_to_neg(fgraph, node):
498499
return [at.add(minuend, at.neg(subtrahend))]
499500

500501

501-
@node_rewriter([exp, log, add, mul, pow])
502+
@node_rewriter([exp, log, add, mul, pow, abs])
502503
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
503504
"""Find measurable transformations from Elemwise operators."""
504505

@@ -558,6 +559,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
558559
transform = ExpTransform()
559560
elif isinstance(scalar_op, Log):
560561
transform = LogTransform()
562+
elif isinstance(scalar_op, Abs):
563+
transform = AbsTransform()
561564
elif isinstance(scalar_op, Pow):
562565
# We only allow for the base to be measurable
563566
if measurable_input_idx != 0:
@@ -701,6 +704,20 @@ def log_jac_det(self, value, *inputs):
701704
return -at.log(value)
702705

703706

707+
class AbsTransform(RVTransform):
708+
name = "abs"
709+
710+
def forward(self, value, *inputs):
711+
return at.abs(value)
712+
713+
def backward(self, value, *inputs):
714+
value = at.switch(value >= 0, value, np.nan)
715+
return -value, value
716+
717+
def log_jac_det(self, value, *inputs):
718+
return at.switch(value >= 0, 0, np.nan)
719+
720+
704721
class PowerTransform(RVTransform):
705722
name = "power"
706723

pymc/tests/logprob/test_transforms.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,19 @@ def test_negative_value_frac_power_transform(power):
876876
assert np.isneginf(x_logp_fn(-2.5))
877877

878878

879+
@pytest.mark.parametrize("test_val", (2.5, -2.5))
880+
def test_absolute_transform(test_val):
881+
x_rv = at.abs(at.random.normal())
882+
y_rv = at.random.halfnormal()
883+
884+
x_vv = x_rv.clone()
885+
y_vv = y_rv.clone()
886+
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}, sum=False))
887+
y_logp_fn = pytensor.function([y_vv], joint_logprob({y_rv: y_vv}, sum=False))
888+
889+
assert np.allclose(x_logp_fn(test_val), y_logp_fn(test_val))
890+
891+
879892
def test_negated_rv_transform():
880893
x_rv = -at.random.halfnormal()
881894
x_rv.name = "x"

0 commit comments

Comments
 (0)