Skip to content

Commit a4aed97

Browse files
authored
Add jitter+full_adapt initialization (#3893)
* Add jitter+full_adapt initialization * Add tests and benchmarks * Actually save file
1 parent ae54ba2 commit a4aed97

File tree

4 files changed

+23
-8
lines changed

4 files changed

+23
-8
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
- GP covariance functions can now be exponentiated by a scalar. See PR [#3852](https://github.com/pymc-devs/pymc3/pull/3852)
1212
- `sample_posterior_predictive` can now feed on `xarray.Dataset` - e.g. from `InferenceData.posterior`. (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846))
1313
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
14-
- `pm.sample` now has support for adapting dense mass matrix using `QuadPotentialFullAdapt` (see [#3596](https://github.com/pymc-devs/pymc3/pull/3596), [#3705](https://github.com/pymc-devs/pymc3/pull/3705) and [#3858](https://github.com/pymc-devs/pymc3/pull/3858))
14+
- `pm.sample` now has support for adapting dense mass matrix using `QuadPotentialFullAdapt` (see [#3596](https://github.com/pymc-devs/pymc3/pull/3596), [#3705](https://github.com/pymc-devs/pymc3/pull/3705), [#3858](https://github.com/pymc-devs/pymc3/pull/3858), and [#3893](https://github.com/pymc-devs/pymc3/pull/3893)). Enable it with `init="adapt_full"` or `init="jitter+adapt_full"`.
1515
- `Moyal` distribution added (see [#3870](https://github.com/pymc-devs/pymc3/pull/3870)).
1616
- `pm.LKJCholeskyCov` now automatically computes and returns the unpacked Cholesky decomposition, the correlations and the standard deviations of the covariance matrix (see [#3881](https://github.com/pymc-devs/pymc3/pull/3881)).
1717

benchmarks/benchmarks/benchmarks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class NUTSInitSuite:
146146
"""Tests initializations for NUTS sampler on models
147147
"""
148148
timeout = 360.0
149-
params = ('adapt_diag', 'jitter+adapt_diag', 'advi+adapt_diag_grad')
149+
params = ('adapt_diag', 'jitter+adapt_diag', 'jitter+adapt_full', 'adapt_full')
150150
number = 1
151151
repeat = 1
152152
draws = 10000
@@ -245,7 +245,7 @@ def freefall(y, t, p):
245245
46.48,
246246
48.18
247247
]).reshape(-1, 1)
248-
248+
249249
ode_model = pm.ode.DifferentialEquation(func=freefall, times=times, n_states=1, n_theta=2, t0=0)
250250
with pm.Model() as model:
251251
# Specify prior distributions for some of our model parameters
@@ -255,12 +255,12 @@ def freefall(y, t, p):
255255
ode_solution = ode_model(y0=[0], theta=[gamma, 9.8])
256256
# The ode_solution has a shape of (n_times, n_states)
257257
Y = pm.Normal("Y", mu=ode_solution, sd=sigma, observed=y)
258-
258+
259259
t0 = time.time()
260260
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
261261
tot = time.time() - t0
262262
ess = pm.ess(trace)
263263
return np.mean([ess.sigma, ess.gamma]) / tot
264264

265-
265+
266266
DifferentialEquationSuite.track_1var_2par_ode_ess.unit = 'Effective samples per second'

pymc3/sampling.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,8 +1853,8 @@ def init_nuts(
18531853
* adapt_diag: Start with an identity mass matrix and then adapt a diagonal based on the
18541854
variance of the tuning samples. All chains use the test value (usually the prior mean)
18551855
as starting point.
1856-
* jitter+adapt_diag: Same as ``adapt_diag``, but use uniform jitter in [-1, 1] as starting
1857-
point in each chain.
1856+
* jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in
1857+
[-1, 1] as starting point in each chain.
18581858
* advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
18591859
sample variance of the tuning samples.
18601860
* advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
@@ -1863,7 +1863,10 @@ def init_nuts(
18631863
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
18641864
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
18651865
* map: Use the MAP as starting point. This is discouraged.
1866-
* adapt_full: Adapt a dense mass matrix using the sample covariances
1866+
* adapt_full: Adapt a dense mass matrix using the sample covariances. All chains use the
1867+
test value (usually the prior mean) as starting point.
1868+
* jitter+adapt_full: Same as ``adapt_full``, but use test value plus a uniform jitter in
1869+
[-1, 1] as starting point in each chain.
18671870
chains: int
18681871
Number of jobs to start.
18691872
n_init: int
@@ -2001,6 +2004,16 @@ def init_nuts(
20012004
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
20022005
cov = np.eye(model.ndim)
20032006
potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
2007+
elif init == 'jitter+adapt_full':
2008+
start = []
2009+
for _ in range(chains):
2010+
mean = {var: val.copy() for var, val in model.test_point.items()}
2011+
for val in mean.values():
2012+
val[...] += 2 * np.random.rand(*val.shape) - 1
2013+
start.append(mean)
2014+
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
2015+
cov = np.eye(model.ndim)
2016+
potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
20042017
else:
20052018
raise ValueError("Unknown initializer: {}.".format(init))
20062019

pymc3/tests/test_sampling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,8 @@ def test_sample_posterior_predictive_w(self):
675675
"advi+adapt_diag_grad",
676676
"map",
677677
"advi_map",
678+
"adapt_full",
679+
"jitter+adapt_full",
678680
],
679681
)
680682
def test_exec_nuts_init(method):

0 commit comments

Comments
 (0)