Skip to content

Commit d4bb701

Browse files
LukeLB (Luke LB) and ricardoV94
authored
Derive logprob for hyperbolic and error transformations (#6664)
* cleaned up if block in find_measurable_transform
* Adapt default `RVTransform.log_jac_det` to univariate and vector transformations.
* Use np.testing in check_jacobian_det

Co-authored-by: Luke LB <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 890c2cc commit d4bb701

File tree

3 files changed

+218
-23
lines changed

3 files changed

+218
-23
lines changed

pymc/logprob/transforms.py

Lines changed: 133 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,28 +42,50 @@
4242
import numpy as np
4343
import pytensor.tensor as pt
4444

45+
from pytensor import scan
4546
from pytensor.gradient import DisconnectedType, jacobian
4647
from pytensor.graph.basic import Apply, Node, Variable
4748
from pytensor.graph.features import AlreadyThere, Feature
4849
from pytensor.graph.fg import FunctionGraph
4950
from pytensor.graph.op import Op
5051
from pytensor.graph.replace import clone_replace
5152
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
52-
from pytensor.scalar import Abs, Add, Exp, Log, Mul, Pow, Sqr, Sqrt
53+
from pytensor.scalar import (
54+
Abs,
55+
Add,
56+
Cosh,
57+
Erf,
58+
Erfc,
59+
Erfcx,
60+
Exp,
61+
Log,
62+
Mul,
63+
Pow,
64+
Sinh,
65+
Sqr,
66+
Sqrt,
67+
Tanh,
68+
)
5369
from pytensor.scan.op import Scan
5470
from pytensor.tensor.exceptions import NotScalarConstantError
5571
from pytensor.tensor.math import (
5672
abs,
5773
add,
74+
cosh,
75+
erf,
76+
erfc,
77+
erfcx,
5878
exp,
5979
log,
6080
mul,
6181
neg,
6282
pow,
6383
reciprocal,
84+
sinh,
6485
sqr,
6586
sqrt,
6687
sub,
88+
tanh,
6789
true_div,
6890
)
6991
from pytensor.tensor.rewriting.basic import (
@@ -122,6 +144,8 @@ def remove_TransformedVariables(fgraph, node):
122144

123145

124146
class RVTransform(abc.ABC):
147+
ndim_supp = None
148+
125149
@abc.abstractmethod
126150
def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
127151
"""Apply the transformation."""
@@ -135,12 +159,16 @@ def backward(
135159

136160
def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
137161
"""Construct the log of the absolute value of the Jacobian determinant."""
138-
# jac = pt.reshape(
139-
# gradient(pt.sum(self.backward(value, *inputs)), [value]), value.shape
140-
# )
141-
# return pt.log(pt.abs(jac))
142-
phi_inv = self.backward(value, *inputs)
143-
return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
162+
if self.ndim_supp not in (0, 1):
163+
raise NotImplementedError(
164+
f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}"
165+
)
166+
if self.ndim_supp == 0:
167+
jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape)
168+
return pt.log(pt.abs(jac))
169+
else:
170+
phi_inv = self.backward(value, *inputs)
171+
return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
144172

145173

146174
@node_rewriter(tracks=None)
@@ -340,7 +368,7 @@ def apply(self, fgraph: FunctionGraph):
340368
class MeasurableTransform(MeasurableElemwise):
341369
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
342370

343-
valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs)
371+
valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx)
344372

345373
# Cannot use `transform` as name because it would clash with the property added by
346374
# the `TransformValuesRewrite`
@@ -540,7 +568,7 @@ def measurable_sub_to_neg(fgraph, node):
540568
return [pt.add(minuend, pt.neg(subtrahend))]
541569

542570

543-
@node_rewriter([exp, log, add, mul, pow, abs])
571+
@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx])
544572
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
545573
"""Find measurable transformations from Elemwise operators."""
546574

@@ -585,13 +613,20 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
585613
measurable_input_idx = 0
586614
transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
587615
transform: RVTransform
588-
if isinstance(scalar_op, Exp):
589-
transform = ExpTransform()
590-
elif isinstance(scalar_op, Log):
591-
transform = LogTransform()
592-
elif isinstance(scalar_op, Abs):
593-
transform = AbsTransform()
594-
elif isinstance(scalar_op, Pow):
616+
617+
transform_dict = {
618+
Exp: ExpTransform(),
619+
Log: LogTransform(),
620+
Abs: AbsTransform(),
621+
Sinh: SinhTransform(),
622+
Cosh: CoshTransform(),
623+
Tanh: TanhTransform(),
624+
Erf: ErfTransform(),
625+
Erfc: ErfcTransform(),
626+
Erfcx: ErfcxTransform(),
627+
}
628+
transform = transform_dict.get(type(scalar_op), None)
629+
if isinstance(scalar_op, Pow):
595630
# We only allow for the base to be measurable
596631
if measurable_input_idx != 0:
597632
return None
@@ -608,7 +643,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
608643
transform = LocTransform(
609644
transform_args_fn=lambda *inputs: inputs[-1],
610645
)
611-
else:
646+
elif transform is None:
612647
transform_inputs = (measurable_input, pt.mul(*other_inputs))
613648
transform = ScaleTransform(
614649
transform_args_fn=lambda *inputs: inputs[-1],
@@ -671,6 +706,87 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
671706
)
672707

673708

709+
class SinhTransform(RVTransform):
710+
name = "sinh"
711+
ndim_supp = 0
712+
713+
def forward(self, value, *inputs):
714+
return pt.sinh(value)
715+
716+
def backward(self, value, *inputs):
717+
return pt.arcsinh(value)
718+
719+
720+
class CoshTransform(RVTransform):
721+
name = "cosh"
722+
ndim_supp = 0
723+
724+
def forward(self, value, *inputs):
725+
return pt.cosh(value)
726+
727+
def backward(self, value, *inputs):
728+
return pt.arccosh(value)
729+
730+
731+
class TanhTransform(RVTransform):
732+
name = "tanh"
733+
ndim_supp = 0
734+
735+
def forward(self, value, *inputs):
736+
return pt.tanh(value)
737+
738+
def backward(self, value, *inputs):
739+
return pt.arctanh(value)
740+
741+
742+
class ErfTransform(RVTransform):
743+
name = "erf"
744+
ndim_supp = 0
745+
746+
def forward(self, value, *inputs):
747+
return pt.erf(value)
748+
749+
def backward(self, value, *inputs):
750+
return pt.erfinv(value)
751+
752+
753+
class ErfcTransform(RVTransform):
754+
name = "erfc"
755+
ndim_supp = 0
756+
757+
def forward(self, value, *inputs):
758+
return pt.erfc(value)
759+
760+
def backward(self, value, *inputs):
761+
return pt.erfcinv(value)
762+
763+
764+
class ErfcxTransform(RVTransform):
765+
name = "erfcx"
766+
ndim_supp = 0
767+
768+
def forward(self, value, *inputs):
769+
return pt.erfcx(value)
770+
771+
def backward(self, value, *inputs):
772+
# computes the inverse of erfcx, this was adapted from
773+
# https://tinyurl.com/4mxfd3cz
774+
x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value)))
775+
776+
def calc_delta_x(value, prior_result):
777+
return prior_result - (pt.erfcx(prior_result) - value) / (
778+
2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi)
779+
)
780+
781+
result, updates = scan(
782+
fn=calc_delta_x,
783+
outputs_info=pt.ones_like(x),
784+
non_sequences=value,
785+
n_steps=10,
786+
)
787+
return result[-1]
788+
789+
674790
class LocTransform(RVTransform):
675791
name = "loc"
676792

tests/distributions/test_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def check_jacobian_det(
112112
)
113113

114114
for yval in domain.vals:
115-
close_to(actual_ljd(yval), computed_ljd(yval), tol)
115+
np.testing.assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol)
116116

117117

118118
def test_simplex():

tests/logprob/test_transforms.py

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,22 @@
4949

5050
from pymc.distributions.transforms import _default_transform, log, logodds
5151
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
52-
from pymc.logprob.basic import factorized_joint_logprob
52+
from pymc.logprob.basic import factorized_joint_logprob, logp
5353
from pymc.logprob.transforms import (
5454
ChainedTransform,
55+
CoshTransform,
56+
ErfcTransform,
57+
ErfcxTransform,
58+
ErfTransform,
5559
ExpTransform,
5660
IntervalTransform,
5761
LocTransform,
5862
LogOddsTransform,
5963
LogTransform,
6064
RVTransform,
6165
ScaleTransform,
66+
SinhTransform,
67+
TanhTransform,
6268
TransformValuesMapping,
6369
TransformValuesRewrite,
6470
transformed_variable,
@@ -327,6 +333,7 @@ def test_fallback_log_jac_det(ndim):
327333

328334
class SquareTransform(RVTransform):
329335
name = "square"
336+
ndim_supp = ndim
330337

331338
def forward(self, value, *inputs):
332339
return pt.power(value, 2)
@@ -336,13 +343,31 @@ def backward(self, value, *inputs):
336343

337344
square_tr = SquareTransform()
338345

339-
value = pt.TensorType("float64", (None,) * ndim)("value")
346+
value = pt.vector("value")
340347
value_tr = square_tr.forward(value)
341348
log_jac_det = square_tr.log_jac_det(value_tr)
342349

343-
test_value = np.full((2,) * ndim, 3)
344-
expected_log_jac_det = -np.log(6) * test_value.size
345-
assert np.isclose(log_jac_det.eval({value: test_value}), expected_log_jac_det)
350+
test_value = np.r_[3, 4]
351+
expected_log_jac_det = -np.log(2 * test_value)
352+
if ndim == 1:
353+
expected_log_jac_det = expected_log_jac_det.sum()
354+
np.testing.assert_array_equal(log_jac_det.eval({value: test_value}), expected_log_jac_det)
355+
356+
357+
@pytest.mark.parametrize("ndim", (None, 2))
358+
def test_fallback_log_jac_det_undefined_ndim(ndim):
359+
class SquareTransform(RVTransform):
360+
name = "square"
361+
ndim_supp = ndim
362+
363+
def forward(self, value, *inputs):
364+
return pt.power(value, 2)
365+
366+
def backward(self, value, *inputs):
367+
return pt.sqrt(value)
368+
369+
with pytest.raises(NotImplementedError, match=r"only implemented for ndim_supp in \(0, 1\)"):
370+
SquareTransform().log_jac_det(0)
346371

347372

348373
def test_hierarchical_uniform_transform():
@@ -989,3 +1014,57 @@ def test_multivariate_transform(shift, scale):
9891014
scale_mat @ cov @ scale_mat.T,
9901015
),
9911016
)
1017+
1018+
1019+
@pytest.mark.parametrize(
1020+
"pt_transform, transform",
1021+
[
1022+
(pt.erf, ErfTransform()),
1023+
(pt.erfc, ErfcTransform()),
1024+
(pt.erfcx, ErfcxTransform()),
1025+
(pt.sinh, SinhTransform()),
1026+
(pt.cosh, CoshTransform()),
1027+
(pt.tanh, TanhTransform()),
1028+
],
1029+
)
1030+
def test_erf_logp(pt_transform, transform):
1031+
base_rv = pt.random.normal(
1032+
0.5, 1, name="base_rv"
1033+
) # Something not centered around 0 is usually better
1034+
rv = pt_transform(base_rv)
1035+
1036+
vv = rv.clone()
1037+
rv_logp = logp(rv, vv)
1038+
1039+
expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv)
1040+
1041+
vv_test = np.array(0.25) # Arbitrary test value
1042+
np.testing.assert_almost_equal(
1043+
rv_logp.eval({vv: vv_test}), np.nan_to_num(expected_logp.eval({vv: vv_test}), nan=-np.inf)
1044+
)
1045+
1046+
1047+
from pymc.testing import Rplusbig, Vector
1048+
from tests.distributions.test_transform import check_jacobian_det
1049+
1050+
1051+
@pytest.mark.parametrize(
1052+
"transform",
1053+
[
1054+
ErfTransform(),
1055+
ErfcTransform(),
1056+
ErfcxTransform(),
1057+
SinhTransform(),
1058+
CoshTransform(),
1059+
TanhTransform(),
1060+
],
1061+
)
1062+
def test_check_jac_det(transform):
1063+
check_jacobian_det(
1064+
transform,
1065+
Vector(Rplusbig, 2),
1066+
pt.dvector,
1067+
[0.1, 0.1],
1068+
elemwise=True,
1069+
rv_var=pt.random.normal(0.5, 1, name="base_rv"),
1070+
)

0 commit comments

Comments
 (0)