Skip to content

Commit a2c3d97

Browse files
committed
Normal CDF and test (does not yet pass)
1 parent 8a46358 commit a2c3d97

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

pymc3/distributions/continuous.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from pymc3.theanof import floatX
1616
from . import transforms
1717

18-
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise
18+
from .dist_math import (bound, logpow, gammaln, betaln, std_cdf, i0, i1,
19+
alltrue_elemwise, zvalue)
1920
from .distribution import Continuous, draw_values, generate_samples, Bound
2021

2122
__all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
@@ -231,6 +232,18 @@ def logp(self, value):
231232
return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
232233
sd > 0)
233234

235+
def logcdf(self, value):
236+
mu = self.mu
237+
sd = self.sd
238+
z = zvalue(value, mu=mu, sd=sd)
239+
240+
return tt.switch(
241+
tt.lt(z, -1.0),
242+
tt.log(tt.erfcx(-z / tt.sqrt(2.)) / 2.) -
243+
tt.sqr(tt.abs_(z)) / 2,
244+
tt.log1p(-tt.erfc(z / tt.sqrt(2.)) / 2.)
245+
)
246+
234247

235248
class HalfNormal(PositiveContinuous):
236249
R"""

pymc3/distributions/dist_math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,9 @@ def i1(x):
9696
x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
9797
np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
9898
+ 14175 / (98304 * x**4)))
99+
100+
def zvalue(value, sd=1, mu=0):
101+
"""
102+
Calculate the z-value for a normal distribution. By default standard normal.
103+
"""
104+
return (value - mu) / tt.sqrt(2 * sd**2)

pymc3/tests/test_distributions.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ def pymc3_matches_scipy(self, pymc3_dist, domain, paramdomains, scipy_dist, extr
296296
model = build_model(pymc3_dist, domain, paramdomains, extra_args)
297297
value = model.named_vars['value']
298298

299-
def logp(args):
300-
return scipy_dist(**args)
299+
def logp(kwargs):
300+
return scipy_dist(**kwargs)
301301
self.check_logp(model, value, domain, paramdomains, logp)
302302

303303
def check_logp(self, model, value, domain, paramdomains, logp_reference):
@@ -307,6 +307,17 @@ def check_logp(self, model, value, domain, paramdomains, logp_reference):
307307
for pt in product(domains, n_samples=100):
308308
pt = Point(pt, model=model)
309309
assert_almost_equal(logp(pt), logp_reference(pt), decimal=6, err_msg=str(pt))
310+
311+
def check_logcdf(self, pymc3_dist, domain, paramdomains, scipy_logcdf):
312+
domains = paramdomains.copy()
313+
domains['value'] = domain
314+
for pt in product(domains, n_samples=100):
315+
params = dict(pt)
316+
scipy_cdf = scipy_logcdf(**params)
317+
value = params.pop('value')
318+
dist = pymc3_dist.dist(**params)
319+
assert_almost_equal(dist.logcdf(value).tag.test_value, scipy_cdf,
320+
decimal=6, err_msg=str(pt))
310321

311322
def check_int_to_1(self, model, value, domain, paramdomains):
312323
pdf = model.fastfn(exp(model.logpt))
@@ -379,6 +390,8 @@ def test_flat(self):
379390
def test_normal(self):
380391
self.pymc3_matches_scipy(Normal, R, {'mu': R, 'sd': Rplus},
381392
lambda value, mu, sd: sp.norm.logpdf(value, mu, sd))
393+
self.check_logcdf(Normal, R, {'mu': R, 'sd': Rplus},
394+
lambda value, mu, sd: sp.norm.logcdf(value, mu, sd))
382395

383396
def test_half_normal(self):
384397
self.pymc3_matches_scipy(HalfNormal, Rplus, {'sd': Rplus},

0 commit comments

Comments
 (0)