Skip to content

Commit 39a2975

Browse files
authored
Derive logprob for exp2, log2, log10, log1p, expm1, log1mexp, log1pexp (softplus), and sigmoid transformations (#6826)
1 parent f338f10 commit 39a2975

File tree

2 files changed

+112
-4
lines changed

2 files changed

+112
-4
lines changed

pymc/logprob/transforms.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,18 @@
6161
Erfc,
6262
Erfcx,
6363
Exp,
64+
Exp2,
65+
Expm1,
6466
Log,
67+
Log1mexp,
68+
Log1p,
69+
Log2,
70+
Log10,
6571
Mul,
6672
Pow,
73+
Sigmoid,
6774
Sinh,
75+
Softplus,
6876
Sqr,
6977
Sqrt,
7078
Tanh,
@@ -82,12 +90,20 @@
8290
erfc,
8391
erfcx,
8492
exp,
93+
exp2,
94+
expm1,
8595
log,
96+
log1mexp,
97+
log1p,
98+
log2,
99+
log10,
86100
mul,
87101
neg,
88102
pow,
89103
reciprocal,
104+
sigmoid,
90105
sinh,
106+
softplus,
91107
sqr,
92108
sqrt,
93109
sub,
@@ -569,8 +585,53 @@ def measurable_sub_to_neg(fgraph, node):
569585
return [pt.add(minuend, pt.neg(subtrahend))]
570586

571587

588+
@node_rewriter([log1p, softplus, log1mexp, log2, log10])
589+
def measurable_special_log_to_log(fgraph, node):
590+
"""Convert log1p, log1mexp, softplus, log2, log10 of `MeasurableVariable`s to log form."""
591+
[inp] = node.inputs
592+
593+
if isinstance(node.op.scalar_op, Log1p):
594+
return [pt.log(1 + inp)]
595+
if isinstance(node.op.scalar_op, Softplus):
596+
return [pt.log(1 + pt.exp(inp))]
597+
if isinstance(node.op.scalar_op, Log1mexp):
598+
return [pt.log(1 - pt.exp(inp))]
599+
if isinstance(node.op.scalar_op, Log2):
600+
return [pt.log(inp) / pt.log(2)]
601+
if isinstance(node.op.scalar_op, Log10):
602+
return [pt.log(inp) / pt.log(10)]
603+
604+
605+
@node_rewriter([expm1, sigmoid, exp2])
606+
def measurable_special_exp_to_exp(fgraph, node):
607+
"""Convert expm1, sigmoid, and exp2 of `MeasurableVariable`s to xp form."""
608+
[inp] = node.inputs
609+
if isinstance(node.op.scalar_op, Exp2):
610+
return [pt.exp(pt.log(2) * inp)]
611+
if isinstance(node.op.scalar_op, Expm1):
612+
return [pt.add(pt.exp(inp), -1)]
613+
if isinstance(node.op.scalar_op, Sigmoid):
614+
return [1 / (1 + pt.exp(-inp))]
615+
616+
572617
@node_rewriter(
573-
[exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx]
618+
[
619+
exp,
620+
log,
621+
add,
622+
mul,
623+
pow,
624+
abs,
625+
sinh,
626+
cosh,
627+
tanh,
628+
arcsinh,
629+
arccosh,
630+
arctanh,
631+
erf,
632+
erfc,
633+
erfcx,
634+
]
574635
)
575636
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
576637
"""Find measurable transformations from Elemwise operators."""
@@ -644,7 +705,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
644705
transform = ScaleTransform(
645706
transform_args_fn=lambda *inputs: inputs[-1],
646707
)
647-
648708
transform_op = MeasurableTransform(
649709
scalar_op=scalar_op,
650710
transform=transform,
@@ -692,6 +752,21 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
692752
"transform",
693753
)
694754

755+
measurable_ir_rewrites_db.register(
756+
"measurable_special_log_to_log",
757+
measurable_special_log_to_log,
758+
"basic",
759+
"transform",
760+
)
761+
762+
measurable_ir_rewrites_db.register(
763+
"measurable_special_exp_to_exp",
764+
measurable_special_exp_to_exp,
765+
"basic",
766+
"transform",
767+
)
768+
769+
695770
measurable_ir_rewrites_db.register(
696771
"find_measurable_transforms",
697772
find_measurable_transforms,

tests/logprob/test_transforms.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,8 @@ def test_chained_transform(self):
746746
ArcsinhTransform(),
747747
ArccoshTransform(),
748748
ArctanhTransform(),
749+
LogTransform(),
750+
ExpTransform(),
749751
],
750752
)
751753
def test_check_jac_det(self, transform):
@@ -1105,7 +1107,6 @@ def test_cosh_rv_transform():
11051107
# Something not centered around 0 is usually better
11061108
base_rv = pt.random.normal(0.5, 1, size=(2,), name="base_rv")
11071109
rv = pt.cosh(base_rv)
1108-
11091110
vv = rv.clone()
11101111
rv_logp = logp(rv, vv)
11111112
with pytest.raises(NotImplementedError):
@@ -1118,14 +1119,46 @@ def test_cosh_rv_transform():
11181119
expected_logp = pt.logaddexp(
11191120
logp(base_rv, back_neg), logp(base_rv, back_pos)
11201121
) + transform.log_jac_det(vv)
1121-
11221122
vv_test = np.array([0.25, 1.5])
11231123
np.testing.assert_allclose(
11241124
rv_logp.eval({vv: vv_test}),
11251125
np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf),
11261126
)
11271127

11281128

1129+
TRANSFORMATIONS = {
1130+
"log1p": (pt.log1p, lambda x: pt.log(1 + x)),
1131+
"softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
1132+
"log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))),
1133+
"log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)),
1134+
"log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)),
1135+
"exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
1136+
"expm1": (pt.expm1, lambda x: pt.exp(x) - 1),
1137+
"sigmoid": (pt.sigmoid, lambda x: 1 / (1 + pt.exp(-x))),
1138+
}
1139+
1140+
1141+
@pytest.mark.parametrize("transform", TRANSFORMATIONS.keys())
1142+
def test_special_log_exp_transforms(transform):
1143+
base_rv = pt.random.normal(name="base_rv")
1144+
vv = pt.scalar("vv")
1145+
1146+
transform_func, ref_func = TRANSFORMATIONS[transform]
1147+
transformed_rv = transform_func(base_rv)
1148+
ref_transformed_rv = ref_func(base_rv)
1149+
1150+
logp_test = logp(transformed_rv, vv)
1151+
logp_ref = logp(ref_transformed_rv, vv)
1152+
1153+
if transform in ["log2", "log10"]:
1154+
# in the cases of log2 and log10 floating point inprecision causes failure
1155+
# from equal_computations so evaluate logp and check all close instead
1156+
vv_test = np.array(0.25)
1157+
np.testing.assert_allclose(logp_ref.eval({vv: vv_test}), logp_test.eval({vv: vv_test}))
1158+
else:
1159+
assert equal_computations([logp_test], [logp_ref])
1160+
1161+
11291162
@pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])])
11301163
@pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])])
11311164
def test_multivariate_rv_transform(shift, scale):

0 commit comments

Comments
 (0)