|
49 | 49 |
|
50 | 50 | from pymc.distributions.transforms import _default_transform, log, logodds
|
51 | 51 | from pymc.logprob.abstract import MeasurableVariable, _logprob
|
52 |
| -from pymc.logprob.basic import conditional_logp, logp |
| 52 | +from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp |
53 | 53 | from pymc.logprob.transforms import (
|
54 | 54 | ArccoshTransform,
|
55 | 55 | ArcsinhTransform,
|
@@ -1080,3 +1080,37 @@ def test_check_jac_det(transform):
|
1080 | 1080 | elemwise=True,
|
1081 | 1081 | rv_var=pt.random.normal(0.5, 1, name="base_rv"),
|
1082 | 1082 | )
|
| 1083 | + |
| 1084 | + |
| 1085 | +def test_logcdf_measurable_transform(): |
| 1086 | + x = pt.exp(pt.random.uniform(0, 1)) |
| 1087 | + value = x.type() |
| 1088 | + logcdf_fn = pytensor.function([value], logcdf(x, value)) |
| 1089 | + |
| 1090 | + assert logcdf_fn(0) == -np.inf |
| 1091 | + np.testing.assert_almost_equal(logcdf_fn(np.exp(0.5)), np.log(0.5)) |
| 1092 | + np.testing.assert_almost_equal(logcdf_fn(5), 0) |
| 1093 | + |
| 1094 | + |
| 1095 | +def test_logcdf_measurable_non_injective_fails(): |
| 1096 | + x = pt.abs(pt.random.uniform(0, 1)) |
| 1097 | + value = x.type() |
| 1098 | + with pytest.raises(NotImplementedError): |
| 1099 | + logcdf(x, value) |
| 1100 | + |
| 1101 | + |
| 1102 | +def test_icdf_measurable_transform(): |
| 1103 | + x = pt.exp(pt.random.uniform(0, 1)) |
| 1104 | + value = x.type() |
| 1105 | + icdf_fn = pytensor.function([value], icdf(x, value)) |
| 1106 | + |
| 1107 | + np.testing.assert_almost_equal(icdf_fn(1e-16), 1) |
| 1108 | + np.testing.assert_almost_equal(icdf_fn(0.5), np.exp(0.5)) |
| 1109 | + np.testing.assert_almost_equal(icdf_fn(1 - 1e-16), np.e) |
| 1110 | + |
| 1111 | + |
| 1112 | +def test_icdf_measurable_non_injective_fails(): |
| 1113 | + x = pt.abs(pt.random.uniform(0, 1)) |
| 1114 | + value = x.type() |
| 1115 | + with pytest.raises(NotImplementedError): |
| 1116 | + icdf(x, value) |
0 commit comments