Skip to content

Commit 562fe16

Browse files
committed
NotImplementedError for icdf of non-injective MeasurableTransforms
1 parent 4847914 commit 562fe16

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

pymc/logprob/transforms.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
448448

449449
backward_value = op.transform_elemwise.backward(value, *other_inputs)
450450

451-
# Some transformations, like squaring may produce multiple backward values
451+
# Fail if transformation is not injective
452+
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
452453
if isinstance(backward_value, tuple):
453454
raise NotImplementedError
454455

@@ -469,6 +470,11 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
469470
input_icdf = _icdf_helper(measurable_input, value)
470471
icdf = op.transform_elemwise.forward(input_icdf, *other_inputs)
471472

473+
# Fail if transformation is not injective
474+
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
475+
if isinstance(op.transform_elemwise.backward(icdf, *other_inputs), tuple):
476+
raise NotImplementedError
477+
472478
return icdf
473479

474480

tests/logprob/test_transforms.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
from pymc.distributions.transforms import _default_transform, log, logodds
5151
from pymc.logprob.abstract import MeasurableVariable, _logprob
52-
from pymc.logprob.basic import conditional_logp, logp
52+
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
5353
from pymc.logprob.transforms import (
5454
ArccoshTransform,
5555
ArcsinhTransform,
@@ -1080,3 +1080,37 @@ def test_check_jac_det(transform):
10801080
elemwise=True,
10811081
rv_var=pt.random.normal(0.5, 1, name="base_rv"),
10821082
)
1083+
1084+
1085+
def test_logcdf_measurable_transform():
1086+
x = pt.exp(pt.random.uniform(0, 1))
1087+
value = x.type()
1088+
logcdf_fn = pytensor.function([value], logcdf(x, value))
1089+
1090+
assert logcdf_fn(0) == -np.inf
1091+
np.testing.assert_almost_equal(logcdf_fn(np.exp(0.5)), np.log(0.5))
1092+
np.testing.assert_almost_equal(logcdf_fn(5), 0)
1093+
1094+
1095+
def test_logcdf_measurable_non_injective_fails():
1096+
x = pt.abs(pt.random.uniform(0, 1))
1097+
value = x.type()
1098+
with pytest.raises(NotImplementedError):
1099+
logcdf(x, value)
1100+
1101+
1102+
def test_icdf_measurable_transform():
1103+
x = pt.exp(pt.random.uniform(0, 1))
1104+
value = x.type()
1105+
icdf_fn = pytensor.function([value], icdf(x, value))
1106+
1107+
np.testing.assert_almost_equal(icdf_fn(1e-16), 1)
1108+
np.testing.assert_almost_equal(icdf_fn(0.5), np.exp(0.5))
1109+
np.testing.assert_almost_equal(icdf_fn(1 - 1e-16), np.e)
1110+
1111+
1112+
def test_icdf_measurable_non_injective_fails():
1113+
x = pt.abs(pt.random.uniform(0, 1))
1114+
value = x.type()
1115+
with pytest.raises(NotImplementedError):
1116+
icdf(x, value)

0 commit comments

Comments
 (0)