diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py
index 05616fbe4c..f44de3c018 100644
--- a/pymc/distributions/truncated.py
+++ b/pymc/distributions/truncated.py
@@ -352,6 +352,78 @@ def truncated_logprob(op, values, *inputs, **kwargs):
     return logp
 
 
+@_logcdf.register(TruncatedRV)
+def truncated_logcdf(op, value, *inputs, **kwargs):
+    """Return the log-CDF of a truncated random variable at ``value``.
+
+    The base log-CDF is shifted by the mass below the lower bound and renormalized
+    by the mass kept between the bounds.
+
+    Parameters
+    ----------
+    op
+        Truncated op; ``op.base_rv_op`` supplies the untruncated log-CDF.
+    value
+        Value at which the log-CDF is evaluated.
+    *inputs
+        Node inputs, unpacked as ``(*rv_inputs, lower, upper, rng)``.
+    **kwargs
+        Forwarded unchanged to the base ``_logcdf`` dispatch.
+
+    Returns
+    -------
+    Log-CDF expression. It is ``-inf`` for ``value < lower`` and ``0`` for
+    ``value > upper`` whenever the respective bound is active, and it is guarded
+    by a ``lower_bound <= upper_bound`` check when both bounds are active.
+    """
+    *rv_inputs, lower, upper, rng = inputs
+    # The base dispatch is called with the rng placed first.
+    rv_inputs = [rng, *rv_inputs]
+
+    base_rv_op = op.base_rv_op
+    logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
+
+    # For left truncated discrete RVs, we don't want to include the lower bound in the
+    # normalization term
+    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
+    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
+    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+
+    # A bound is inactive only when it is a constant at the matching infinity: -inf
+    # for the lower bound, +inf for the upper bound. ``np.isinf`` would also accept a
+    # constant -inf upper bound and silently drop it.
+    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
+    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isposinf(upper.value)))
+
+    # Log of the base probability mass kept by the truncation.
+    lognorm = 0
+    if is_lower_bounded and is_upper_bounded:
+        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
+    elif is_lower_bounded:
+        lognorm = pt.log1mexp(lower_logcdf)
+    elif is_upper_bounded:
+        lognorm = upper_logcdf
+
+    logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
+    logcdf_trunc = logcdf_numerator - lognorm
+
+    # Outside an active bound the CDF is pinned to 0 (log: -inf) or 1 (log: 0).
+    if is_lower_bounded:
+        logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_trunc)
+
+    if is_upper_bounded:
+        logcdf_trunc = pt.switch(value <= upper, logcdf_trunc, 0.0)
+
+    if is_lower_bounded and is_upper_bounded:
+        logcdf_trunc = check_parameters(
+            logcdf_trunc,
+            pt.le(lower, upper),
+            msg="lower_bound <= upper_bound",
+        )
+
+    return logcdf_trunc
+
+
 @_truncated.register(NormalRV)
 def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma):
     return TruncatedNormal.dist(
diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py
index 7502260dc8..53e933285b 100644
--- a/tests/distributions/test_truncated.py
+++ b/tests/distributions/test_truncated.py
@@ -26,7 +26,7 @@
 from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
 from pymc.exceptions import TruncationError
 from pymc.logprob.abstract import _icdf
-from pymc.logprob.basic import logp
+from pymc.logprob.basic import logcdf, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.logprob.utils import ParameterValueError
 from pymc.testing import assert_moment_is_expected
@@ -165,6 +165,39 @@ def test_truncation_continuous_logp(op_type, lower, upper):
             assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))
 
 
+@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_continuous_logcdf(op_type, lower, upper):
+    """Check the truncated normal log-CDF against ``scipy.stats.truncnorm``."""
+    loc = 0.15
+    scale = 10
+    op = icdf_normal if op_type == "icdf" else rejection_normal
+
+    x = op(loc, scale, name="x")
+    xt = Truncated.dist(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+
+    xt_vv = xt.clone()
+    xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))
+
+    # truncnorm takes its bounds in standardized units, i.e. (bound - loc) / scale.
+    ref_xt = scipy.stats.truncnorm(
+        (lower - loc) / scale,
+        (upper - loc) / scale,
+        loc,
+        scale,
+    )
+    # Probe one below, at, and one above every finite bound.
+    for bound in (lower, upper):
+        if np.isinf(bound):
+            # Skip only this bound; returning here would leave a finite upper bound
+            # untested whenever the lower bound is -inf.
+            continue
+        for offset in (-1, 0, 1):
+            test_xt_v = bound + offset
+            assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt.logcdf(test_xt_v))
+
+
 @pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
 @pytest.mark.parametrize("op_type", ["icdf", "rejection"])
 def test_truncation_discrete_random(op_type, lower, upper):
@@ -232,6 +265,42 @@ def ref_xt_logpmf(value):
     assert np.isclose(log_integral, 0.0, atol=1e-5)
 
 
+@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_discrete_logcdf(op_type, lower, upper):
+    """Check the truncated geometric log-CDF against a scipy-based reference."""
+    p = 0.7
+    op = icdf_geometric if op_type == "icdf" else rejection_geometric
+
+    x = op(p, name="x")
+    xt = Truncated.dist(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+
+    xt_vv = xt.clone()
+    xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))
+
+    ref_xt = scipy.stats.geom(p)
+    # Mass kept by the truncation; ``lower`` itself belongs to the support, hence
+    # the subtraction of the CDF at ``lower - 1``.
+    log_norm = np.log(ref_xt.cdf(upper) - ref_xt.cdf(lower - 1))
+
+    def ref_xt_logcdf(value):
+        if value < lower:
+            return -np.inf
+        if value > upper:
+            return 0.0
+
+        return np.log(ref_xt.cdf(value) - ref_xt.cdf(lower - 1)) - log_norm
+
+    # Probe one below, at, and one above every finite bound.
+    for bound in (lower, upper):
+        if np.isinf(bound):
+            continue
+        for offset in (-1, 0, 1):
+            test_xt_v = bound + offset
+            assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt_logcdf(test_xt_v))
+
+
 def test_truncation_exceptions():
     with pytest.raises(ValueError, match="lower and upper cannot both be None"):
         Truncated.dist(pt.random.normal())