diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 2b41cb46d8..9457c64ecf 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -85,6 +85,7 @@ def polyagamma_cdf(*args, **kwargs): check_parameters, clipped_beta_rvs, i0e, + log_diff_normal_cdf, log_normal, logpow, normal_lccdf, @@ -743,6 +744,31 @@ def logp(value, mu, sigma, lower, upper): return logp + def logcdf(value, mu, sigma, lower, upper): + logcdf = log_diff_normal_cdf(mu, sigma, value, lower) - log_diff_normal_cdf( + mu, sigma, upper, lower + ) + + is_lower_bounded = not ( + isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)) + ) + is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value))) + + if is_lower_bounded: + logcdf = pt.switch(value < lower, -np.inf, logcdf) + + if is_upper_bounded: + logcdf = pt.switch(value <= upper, logcdf, 0.0) + + if is_lower_bounded and is_upper_bounded: + logcdf = check_parameters( + logcdf, + pt.le(lower, upper), + msg="lower_bound <= upper_bound", + ) + + return logcdf + @_default_transform.register(TruncatedNormal) def truncated_normal_default_transform(op, rv): diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 9bb147e7c9..5c287fd76b 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -37,6 +37,7 @@ Circ, Domain, R, + Rminusbig, Rplus, Rplusbig, Rplusunif, @@ -934,6 +935,11 @@ def scipy_logp(value, mu, sigma, lower, upper): value, (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma ) + def scipy_logcdf(value, mu, sigma, lower, upper): + return st.truncnorm.logcdf( + value, (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma + ) + check_logp( pm.TruncatedNormal, R, @@ -961,6 +967,33 @@ def scipy_logp(value, mu, sigma, lower, upper): skip_paramdomain_outside_edge_test=True, ) + check_logcdf( + pm.TruncatedNormal, + R, + {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig}, + scipy_logcdf, + decimal=select_by_precision(float64=6, float32=1), + skip_paramdomain_outside_edge_test=True, + ) + + check_logcdf( + pm.TruncatedNormal, + R, + {"mu": R, "sigma": Rplusbig, "upper": Rplusbig}, + ft.partial(scipy_logcdf, lower=-np.inf), + decimal=select_by_precision(float64=6, float32=1), + skip_paramdomain_outside_edge_test=True, + ) + + check_logcdf( + pm.TruncatedNormal, + R, + {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig}, + ft.partial(scipy_logcdf, upper=np.inf), + decimal=select_by_precision(float64=6, float32=1), + skip_paramdomain_outside_edge_test=True, + ) + # This is a regression test for #6128: Check that having one out-of-bound value # in an input array does not set all logp values to -inf dist = pm.TruncatedNormal.dist(mu=1, sigma=2, lower=0, upper=3)