Skip to content

Add logcdf implementation for Truncated distributions #6690

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,50 @@ def truncated_logprob(op, values, *inputs, **kwargs):
return logp


@_logcdf.register(TruncatedRV)
def truncated_logcdf(op, value, *inputs, **kwargs):
*rv_inputs, lower, upper, rng = inputs
rv_inputs = [rng, *rv_inputs]

base_rv_op = op.base_rv_op
logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)

# For left truncated discrete RVs, we don't want to include the lower bound in the
# normalization term
lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)

is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))

lognorm = 0
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
elif is_upper_bounded:
lognorm = upper_logcdf

logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
logcdf_trunc = logcdf_numerator - lognorm

if is_lower_bounded:
logcdf_trunc = pt.switch(value < lower, -np.inf, logcdf_trunc)

if is_upper_bounded:
logcdf_trunc = pt.switch(value <= upper, logcdf_trunc, 0.0)

if is_lower_bounded and is_upper_bounded:
logcdf_trunc = check_parameters(
logcdf_trunc,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)

return logcdf_trunc


@_truncated.register(NormalRV)
def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma):
return TruncatedNormal.dist(
Expand Down
62 changes: 61 additions & 1 deletion tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _icdf
from pymc.logprob.basic import logp
from pymc.logprob.basic import logcdf, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.logprob.utils import ParameterValueError
from pymc.testing import assert_moment_is_expected
Expand Down Expand Up @@ -165,6 +165,34 @@ def test_truncation_continuous_logp(op_type, lower, upper):
assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))


@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_continuous_logcdf(op_type, lower, upper):
loc = 0.15
scale = 10
op = icdf_normal if op_type == "icdf" else rejection_normal

x = op(loc, scale, name="x")
xt = Truncated.dist(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, TruncatedRV)

xt_vv = xt.clone()
xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))

ref_xt = scipy.stats.truncnorm(
(lower - loc) / scale,
(upper - loc) / scale,
loc,
scale,
)
for bound in (lower, upper):
if np.isinf(bound):
return
for offset in (-1, 0, 1):
test_xt_v = bound + offset
assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt.logcdf(test_xt_v))


@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_discrete_random(op_type, lower, upper):
Expand Down Expand Up @@ -232,6 +260,38 @@ def ref_xt_logpmf(value):
assert np.isclose(log_integral, 0.0, atol=1e-5)


@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_discrete_logcdf(op_type, lower, upper):
p = 0.7
op = icdf_geometric if op_type == "icdf" else rejection_geometric

x = op(p, name="x")
xt = Truncated.dist(x, lower=lower, upper=upper)
assert isinstance(xt.owner.op, TruncatedRV)

xt_vv = xt.clone()
xt_logcdf_fn = pytensor.function([xt_vv], logcdf(xt, xt_vv))

ref_xt = scipy.stats.geom(p)
log_norm = np.log(ref_xt.cdf(upper) - ref_xt.cdf(lower - 1))

def ref_xt_logcdf(value):
if value < lower:
return -np.inf
elif value > upper:
return 0.0

return np.log(ref_xt.cdf(value) - ref_xt.cdf(lower - 1)) - log_norm

for bound in (lower, upper):
if np.isinf(bound):
continue
for offset in (-1, 0, 1):
test_xt_v = bound + offset
assert np.isclose(xt_logcdf_fn(test_xt_v), ref_xt_logcdf(test_xt_v))


def test_truncation_exceptions():
with pytest.raises(ValueError, match="lower and upper cannot both be None"):
Truncated.dist(pt.random.normal())
Expand Down