Skip to content

Commit cce9dfe

Browse files
author
Junpeng Lao
authored
Merge pull request #2163 from a-rodin/interpolated
Add Interpolated distribution class
2 parents 1253da3 + 4a381f1 commit cce9dfe

File tree

6 files changed

+246
-78
lines changed

6 files changed

+246
-78
lines changed

docs/source/notebooks/updating_priors.ipynb

Lines changed: 91 additions & 74 deletions
Large diffs are not rendered by default.

pymc3/distributions/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .continuous import SkewNormal
2626
from .continuous import Triangular
2727
from .continuous import Gumbel
28+
from .continuous import Interpolated
2829

2930
from .discrete import Binomial
3031
from .discrete import BetaBinomial
@@ -132,5 +133,6 @@
132133
'NormalMixture',
133134
'Triangular',
134135
'DiscreteWeibull',
135-
'Gumbel'
136+
'Gumbel',
137+
'Interpolated'
136138
]

pymc3/distributions/continuous.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@
1010
import numpy as np
1111
import theano.tensor as tt
1212
from scipy import stats
13+
from scipy.interpolate import InterpolatedUnivariateSpline
1314
import warnings
1415

1516
from pymc3.theanof import floatX
1617
from . import transforms
1718

18-
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise
19+
from .dist_math import bound, logpow, gammaln, betaln, std_cdf, i0, i1, alltrue_elemwise, DifferentiableSplineWrapper
1920
from .distribution import Continuous, draw_values, generate_samples, Bound
2021

2122
__all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
2223
'StudentT', 'Cauchy', 'HalfCauchy', 'Gamma', 'Weibull',
2324
'HalfStudentT', 'StudentTpos', 'Lognormal', 'ChiSquared',
2425
'HalfNormal', 'Wald', 'Pareto', 'InverseGamma', 'ExGaussian',
25-
'VonMises', 'SkewNormal']
26+
'VonMises', 'SkewNormal', 'Interpolated']
2627

2728

2829
class PositiveContinuous(Continuous):
@@ -1389,3 +1390,71 @@ def random(self, point=None, size=None, repeat=None):
13891390
def logp(self, value):
13901391
scaled = (value - self.mu) / self.beta
13911392
return bound(-scaled - tt.exp(-scaled) - tt.log(self.beta), self.beta > 0)
1393+
1394+
class Interpolated(Continuous):
1395+
R"""
1396+
Probability distribution defined as a linear interpolation of
1397+
of a set of points and values of probability density function
1398+
evaluated on them.
1399+
1400+
The points are not variables, but plain array-like objects, so
1401+
they are constant and cannot be sampled.
1402+
1403+
======== =========================================
1404+
Support :math:`x \in [x_points[0], x_points[-1]]`
1405+
======== =========================================
1406+
1407+
Parameters
1408+
----------
1409+
x_points : array-like
1410+
A monotonically growing list of values
1411+
pdf_points : array-like
1412+
Probability density function evaluated at points from `x`
1413+
"""
1414+
1415+
def __init__(self, x_points, pdf_points, transform='interval',
1416+
*args, **kwargs):
1417+
if transform == 'interval':
1418+
transform = transforms.interval(x_points[0], x_points[-1])
1419+
super(Interpolated, self).__init__(transform=transform,
1420+
*args, **kwargs)
1421+
1422+
interp = InterpolatedUnivariateSpline(x_points, pdf_points, k=1, ext='zeros')
1423+
Z = interp.integral(x_points[0], x_points[-1])
1424+
1425+
self.Z = tt.as_tensor_variable(Z)
1426+
self.interp_op = DifferentiableSplineWrapper(interp)
1427+
self.x_points = x_points
1428+
self.pdf_points = pdf_points / Z
1429+
self.cdf_points = interp.antiderivative()(x_points) / Z
1430+
1431+
self.median = self._argcdf(0.5)
1432+
1433+
def _argcdf(self, p):
1434+
pdf = self.pdf_points
1435+
cdf = self.cdf_points
1436+
x = self.x_points
1437+
1438+
index = np.searchsorted(cdf, p) - 1
1439+
slope = (pdf[index + 1] - pdf[index]) / (x[index + 1] - x[index])
1440+
1441+
return x[index] + np.where(
1442+
np.abs(slope) <= 1e-8,
1443+
np.where(
1444+
np.abs(pdf[index]) <= 1e-8,
1445+
np.zeros(index.shape),
1446+
(p - cdf[index]) / pdf[index]
1447+
),
1448+
(-pdf[index] + np.sqrt(pdf[index] ** 2 + 2 * slope * (p - cdf[index]))) / slope
1449+
)
1450+
1451+
def _random(self, size=None):
1452+
return self._argcdf(np.random.uniform(size=size))
1453+
1454+
def random(self, point=None, size=None, repeat=None):
1455+
return generate_samples(self._random,
1456+
dist_shape=self.shape,
1457+
size=size)
1458+
1459+
def logp(self, value):
1460+
return tt.log(self.interp_op(value) / self.Z)

pymc3/distributions/dist_math.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,35 @@ def conjugate_solve_triangular(outer, inner):
364364
else:
365365
grad = tt.triu(s + s.T) - tt.diag(tt.diagonal(s))
366366
return [tt.switch(ok, grad, floatX(np.nan))]
367+
368+
class SplineWrapper (theano.Op):
369+
"""
370+
Creates a theano operation from scipy.interpolate.UnivariateSpline
371+
"""
372+
373+
__props__ = ('spline',)
374+
itypes = [tt.dscalar]
375+
otypes = [tt.dscalar]
376+
377+
def __init__(self, spline):
378+
self.spline = spline
379+
380+
def perform(self, node, inputs, output_storage):
381+
x, = inputs
382+
output_storage[0][0] = np.asarray(self.spline(x))
383+
384+
class DifferentiableSplineWrapper (SplineWrapper):
385+
"""
386+
Creates a theano operation with defined gradient from
387+
scipy.interpolate.UnivariateSpline
388+
"""
389+
390+
def __init__(self, spline):
391+
super(DifferentiableSplineWrapper, self).__init__(spline)
392+
self.spline_grad = SplineWrapper(spline.derivative())
393+
self.__props__ += ('spline_grad',)
394+
395+
def grad(self, inputs, grads):
396+
x, = inputs
397+
x_grad, = grads
398+
return [x_grad * self.spline_grad(x)]

pymc3/tests/test_distributions.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
1414
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
1515
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
16-
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull, Gumbel)
16+
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull, Gumbel,
17+
Interpolated)
1718
from ..distributions import continuous
1819
from pymc3.theanof import floatX
1920
from numpy import array, inf, log, exp
@@ -791,3 +792,30 @@ def test_gumbel(self):
791792
def test_multidimensional_beta_construction(self):
792793
with Model():
793794
Beta('beta', alpha=1., beta=1., shape=(10, 20))
795+
796+
def test_interpolated(self):
797+
for mu in R.vals:
798+
for sd in Rplus.vals:
799+
#pylint: disable=cell-var-from-loop
800+
xmin = mu - 5 * sd
801+
xmax = mu + 5 * sd
802+
803+
class TestedInterpolated (Interpolated):
804+
805+
def __init__(self, **kwargs):
806+
x_points = np.linspace(xmin, xmax, 100000)
807+
pdf_points = sp.norm.pdf(x_points, loc=mu, scale=sd)
808+
super(TestedInterpolated, self).__init__(
809+
x_points=x_points,
810+
pdf_points=pdf_points,
811+
**kwargs
812+
)
813+
814+
def ref_pdf(value):
815+
return np.where(
816+
np.logical_and(value >= xmin, value <= xmax),
817+
sp.norm.logpdf(value, mu, sd),
818+
-np.inf * np.ones(value.shape)
819+
)
820+
821+
self.pymc3_matches_scipy(TestedInterpolated, R, {}, ref_pdf)

pymc3/tests/test_distributions_random.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,26 @@ def ref_rand(size, mu, beta):
577577
return st.gumbel_r.rvs(loc=mu, scale=beta, size=size)
578578
pymc3_random(pm.Gumbel, {'mu': R, 'beta': Rplus}, ref_rand=ref_rand)
579579

580+
def test_interpolated(self):
581+
for mu in R.vals:
582+
for sd in Rplus.vals:
583+
#pylint: disable=cell-var-from-loop
584+
def ref_rand(size):
585+
return st.norm.rvs(loc=mu, scale=sd, size=size)
586+
587+
class TestedInterpolated (pm.Interpolated):
588+
589+
def __init__(self, **kwargs):
590+
x_points = np.linspace(mu - 5 * sd, mu + 5 * sd, 100)
591+
pdf_points = st.norm.pdf(x_points, loc=mu, scale=sd)
592+
super(TestedInterpolated, self).__init__(
593+
x_points=x_points,
594+
pdf_points=pdf_points,
595+
**kwargs
596+
)
597+
598+
pymc3_random(TestedInterpolated, {}, ref_rand=ref_rand)
599+
580600
@pytest.mark.skip('Wishart random sampling not implemented.\n'
581601
'See https://github.com/pymc-devs/pymc3/issues/538')
582602
def test_wishart(self):

0 commit comments

Comments
 (0)