Skip to content

Commit 5236d3e

Browse files
authored
Fix bug in which TruncatedNormal returns -inf for all values if any value is out of bounds (#6128)
* use switch instead of relying on check_parameters
1 parent 16bee91 commit 5236d3e

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

pymc/distributions/continuous.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -760,31 +760,40 @@ def logp(
760760
-------
761761
TensorVariable
762762
"""
763-
unbounded_lower = isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf)
764-
unbounded_upper = isinstance(upper, TensorConstant) and np.all(upper.value == np.inf)
763+
is_lower_bounded = not (
764+
isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value))
765+
)
766+
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
765767

766-
if not unbounded_lower and not unbounded_upper:
768+
if is_lower_bounded and is_upper_bounded:
767769
lcdf_a = normal_lcdf(mu, sigma, lower)
768770
lcdf_b = normal_lcdf(mu, sigma, upper)
769771
lsf_a = normal_lccdf(mu, sigma, lower)
770772
lsf_b = normal_lccdf(mu, sigma, upper)
771773
norm = at.switch(lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a))
772-
elif not unbounded_lower:
774+
elif is_lower_bounded:
773775
norm = normal_lccdf(mu, sigma, lower)
774-
elif not unbounded_upper:
776+
elif is_upper_bounded:
775777
norm = normal_lcdf(mu, sigma, upper)
776778
else:
777779
norm = 0.0
778780

779781
logp = _logprob(normal, (value,), None, None, None, mu, sigma) - norm
780-
bounds = []
781-
if not unbounded_lower:
782-
bounds.append(value >= lower)
783-
if not unbounded_upper:
784-
bounds.append(value <= upper)
785-
if not unbounded_lower and not unbounded_upper:
786-
bounds.append(lower <= upper)
787-
return check_parameters(logp, *bounds)
782+
783+
if is_lower_bounded:
784+
logp = at.switch(value < lower, -np.inf, logp)
785+
786+
if is_upper_bounded:
787+
logp = at.switch(value > upper, -np.inf, logp)
788+
789+
if is_lower_bounded and is_upper_bounded:
790+
logp = check_parameters(
791+
logp,
792+
at.le(lower, upper),
793+
msg="lower_bound <= upper_bound",
794+
)
795+
796+
return logp
788797

789798

790799
@_default_transform.register(TruncatedNormal)

pymc/tests/distributions/test_continuous.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,14 @@ def scipy_logp(value, mu, sigma, lower, upper):
858858
skip_paramdomain_outside_edge_test=True,
859859
)
860860

861+
# This is a regression test for #6128: Check that having one out-of-bound value
862+
# in an input array does not set all logp values to -inf
863+
dist = pm.TruncatedNormal.dist(mu=1, sigma=2, lower=0, upper=3)
864+
logp = pm.logp(dist, [-2.0, 1.0, 4.0]).eval()
865+
assert np.isinf(logp[0])
866+
assert np.isfinite(logp[1])
867+
assert np.isinf(logp[2])
868+
861869
def test_get_tau_sigma(self):
862870
sigma = np.array(2)
863871
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])

0 commit comments

Comments
 (0)