
Commit a222c86

aloctavodia authored and Junpeng Lao committed
Update Loo, implement improved algorithm (#2730)
* update loo
* small fixes
* remove print
* remove unused import
* add to release-notes
* automatic reff calculation
* fix eff_ave
1 parent 250cc00 commit a222c86

File tree

2 files changed: +173 -48 lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 1 deletion

@@ -6,7 +6,8 @@
 ### New features
 
 - Improve NUTS initialization `advi+adapt_diag_grad` and add `jitter+adapt_diag_grad` (#2643)
-
+- Update loo, new improved algorithm (#2730)
+
 ### Fixes
 - Fixed `compareplot` to use `loo` output.
 - Add test for `model.logp_array` and `model.bijection` (#2724)

pymc3/stats.py

Lines changed: 171 additions & 47 deletions

@@ -13,7 +13,6 @@
 
 from scipy.misc import logsumexp
 from scipy.stats import dirichlet
-from scipy.stats.distributions import pareto
 from scipy.optimize import minimize
 
 from .backends import tracetab as ttab
@@ -235,10 +234,10 @@ def waic(trace, model=None, pointwise=False, progressbar=False):
     return WAIC_r(waic, waic_se, p_waic)
 
 
-def loo(trace, model=None, pointwise=False, progressbar=False):
-    """Calculates leave-one-out (LOO) cross-validation for out of sample predictive
-    model fit, following Vehtari et al. (2015). Cross-validation is computed using
-    Pareto-smoothed importance sampling (PSIS).
+def loo(trace, model=None, pointwise=False, reff=None, progressbar=False):
+    """Calculates leave-one-out (LOO) cross-validation for out of sample
+    predictive model fit, following Vehtari et al. (2015). Cross-validation is
+    computed using Pareto-smoothed importance sampling (PSIS).
 
     Parameters
     ----------
@@ -248,6 +247,10 @@ def loo(trace, model=None, pointwise=False, progressbar=False):
     pointwise: bool
         if True the pointwise predictive accuracy will be returned.
         Default False
+    reff : float
+        relative MCMC efficiency, `effective_n / n` i.e. number of effective
+        samples divided by the number of actual samples. Computed from trace by
+        default.
     progressbar: bool
         Whether or not to display a progress bar in the command line. The
         bar shows the percentage of completion, the evaluation speed, and
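
The docstring's "computed from trace by default" refers to the logic added in the next hunk. For orientation, a minimal sketch of what passing `reff` explicitly looks like, assuming a multi-chain `trace` returned by `pm.sample` (the variable names are placeholders, not part of the commit):

    import pymc3 as pm

    # relative efficiency: mean effective sample size over all variables,
    # divided by the total number of draws across chains
    eff_ave = pm.stats.dict2pd(pm.effective_n(trace), 'eff').mean()
    reff = eff_ave / (len(trace) * trace.nchains)

    loo_results = pm.loo(trace, pointwise=True, reff=reff)
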
@@ -259,61 +262,37 @@ def loo(trace, model=None, pointwise=False, progressbar=False):
     loo: approximated Leave-one-out cross-validation
     loo_se: standard error of loo
     p_loo: effective number of parameters
-    loo_i: and array of the pointwise predictive accuracy, only if pointwise True
+    loo_i: array of pointwise predictive accuracy, only if pointwise True
     """
     model = modelcontext(model)
 
+    if reff is None:
+        if trace.nchains == 1:
+            reff = 1.
+        else:
+            eff = pm.effective_n(trace)
+            eff_ave = pm.stats.dict2pd(eff, 'eff').mean()
+            samples = len(trace) * trace.nchains
+            reff = eff_ave / samples
+
     log_py = _log_post_trace(trace, model, progressbar=progressbar)
     if log_py.size == 0:
         raise ValueError('The model does not contain observed values.')
 
-    # Importance ratios
-    r = np.exp(-log_py)
-    r_sorted = np.sort(r, axis=0)
-
-    # Extract largest 20% of importance ratios and fit generalized Pareto to each
-    # (returns tuple with shape, location, scale)
-    q80 = int(len(log_py) * 0.8)
-    pareto_fit = np.apply_along_axis(
-        lambda x: pareto.fit(x, floc=0), 0, r_sorted[q80:])
-
-    if np.any(pareto_fit[0] > 0.7):
+    lw, ks = _psislw(-log_py, reff)
+    lw += log_py
+    if np.any(ks > 0.7):
         warnings.warn("""Estimated shape parameter of Pareto distribution is
         greater than 0.7 for one or more samples.
-        You should consider using a more robust model, this is
-        because importance sampling is less likely to work well if the marginal
+        You should consider using a more robust model, this is because
+        importance sampling is less likely to work well if the marginal
         posterior and LOO posterior are very different. This is more likely to
         happen with a non-robust model and highly influential observations.""")
 
-    elif np.any(pareto_fit[0] > 0.5):
-        warnings.warn("""Estimated shape parameter of Pareto distribution is
-        greater than 0.5 for one or more samples. This may indicate
-        that the variance of the Pareto smoothed importance sampling estimate
-        is very large.""")
-
-    # Calculate expected values of the order statistics of the fitted Pareto
-    S = len(r_sorted)
-    M = S - q80
-    z = (np.arange(M) + 0.5) / M
-    expvals = map(lambda x: pareto.ppf(z, x[0], scale=x[2]), pareto_fit.T)
-
-    # Replace importance ratios with order statistics of fitted Pareto
-    r_sorted[q80:] = np.vstack(expvals).T
-    # Unsort ratios (within columns) before using them as weights
-    r_new = np.array([r[np.argsort(i)]
-                      for r, i in zip(r_sorted.T, np.argsort(r.T, axis=1))]).T
-
-    # Truncate weights to guarantee finite variance
-    w = np.minimum(r_new, r_new.mean(axis=0) * S**0.75)
-
-    loo_lppd_i = - 2. * logsumexp(log_py, axis=0, b=w / np.sum(w, axis=0))
-
-    loo_lppd_se = np.sqrt(len(loo_lppd_i) * np.var(loo_lppd_i))
-
-    loo_lppd = np.sum(loo_lppd_i)
-
+    loo_lppd_i = - 2 * logsumexp(lw, axis=0)
+    loo_lppd = loo_lppd_i.sum()
+    loo_lppd_se = (len(loo_lppd_i) * np.var(loo_lppd_i)) ** 0.5
     lppd = np.sum(logsumexp(log_py, axis=0, b=1. / log_py.shape[0]))
-
     p_loo = lppd + (0.5 * loo_lppd)
 
     if pointwise:
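
The heavy lifting moves into `_psislw`; what remains in `loo` is the reduction from smoothed log weights to the deviance-scale estimate. A self-contained numpy sketch of that final step, using plainly normalized ratios as a stand-in for PSIS so it runs without the helpers (illustrative only):

    import numpy as np
    from scipy.misc import logsumexp  # scipy.special.logsumexp in later scipy

    # synthetic pointwise log-likelihoods: 500 draws x 20 observations
    log_py = np.random.normal(-1., 0.5, size=(500, 20))

    # stand-in for _psislw: normalize raw importance ratios per column
    lw = -log_py - logsumexp(-log_py, axis=0)
    lw += log_py

    loo_lppd_i = - 2 * logsumexp(lw, axis=0)  # pointwise contributions
    loo_lppd = loo_lppd_i.sum()
    loo_lppd_se = (len(loo_lppd_i) * np.var(loo_lppd_i)) ** 0.5
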
@@ -324,6 +303,151 @@ def loo(trace, model=None, pointwise=False, progressbar=False):
         return LOO_r(loo_lppd, loo_lppd_se, p_loo)
 
 
+def _psislw(lw, reff):
+    """Pareto smoothed importance sampling (PSIS).
+
+    Parameters
+    ----------
+    lw : array
+        Array of size (n_samples, n_observations)
+    reff : float
+        relative MCMC efficiency, `effective_n / n`
+
+    Returns
+    -------
+    lw_out : array
+        Smoothed log weights
+    kss : array
+        Pareto tail indices
+    """
+    n, m = lw.shape
+
+    lw_out = np.copy(lw, order='F')
+    kss = np.empty(m)
+
+    # precalculate constants
+    cutoff_ind = - int(np.ceil(min(n / 5., 3 * (n / reff) ** 0.5))) - 1
+    cutoffmin = np.log(np.finfo(float).tiny)
+    k_min = 1. / 3
+
+    # loop over sets of log weights
+    for i, x in enumerate(lw_out.T):
+        # improve numerical accuracy
+        x -= np.max(x)
+        # sort the array
+        x_sort_ind = np.argsort(x)
+        # divide log weights into body and right tail
+        xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin)
+
+        expxcutoff = np.exp(xcutoff)
+        tailinds, = np.where(x > xcutoff)
+        x2 = x[tailinds]
+        n2 = len(x2)
+        if n2 <= 4:
+            # not enough tail samples for gpdfit
+            k = np.inf
+        else:
+            # order of tail samples
+            x2si = np.argsort(x2)
+            # fit generalized Pareto distribution to the right tail samples
+            x2 = np.exp(x2) - expxcutoff
+            k, sigma = _gpdfit(x2[x2si])
+
+        if k >= k_min and not np.isinf(k):
+            # no smoothing if short tail or GPD fit failed
+            # compute ordered statistic for the fit
+            sti = np.arange(0.5, n2) / n2
+            qq = _gpinv(sti, k, sigma)
+            qq = np.log(qq + expxcutoff)
+            # place the smoothed tail into the output array
+            x[tailinds[x2si]] = qq
+            # truncate smoothed values to the largest raw weight 0
+            x[x > 0] = 0
+        # renormalize weights
+        x -= logsumexp(x)
+        # store tail index k
+        kss[i] = k
+
+    return lw_out, kss
+
+
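`_psislw` leaves the bulk of each weight column alone and replaces only the right tail (roughly the largest 20% of weights, per the cutoff above) with order statistics of a fitted generalized Pareto. A hypothetical smoke test, assuming `_psislw` and the helpers below are in scope:

    import numpy as np

    np.random.seed(42)
    # heavy-tailed raw log weights: 1000 draws x 5 observations
    lw = np.log(np.abs(np.random.standard_cauchy(size=(1000, 5))))

    smoothed, ks = _psislw(lw, reff=1.)

    # smoothed columns are normalized: weights sum to one per observation
    assert np.allclose(np.exp(smoothed).sum(axis=0), 1.)
    # ks values above 0.7 would trigger the warning in loo above
    print(ks)
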
+def _gpdfit(x):
+    """Estimate the parameters for the Generalized Pareto Distribution (GPD)
+
+    Empirical Bayes estimate for the parameters of the generalized Pareto
+    distribution given the data.
+
+    Parameters
+    ----------
+    x : array
+        sorted 1D data array
+
+    Returns
+    -------
+    k : float
+        estimated shape parameter
+    sigma : float
+        estimated scale parameter
+    """
+    prior_bs = 3
+    prior_k = 10
+    n = len(x)
+    m = 30 + int(n**0.5)
+
+    bs = 1 - np.sqrt(m / (np.arange(1, m + 1, dtype=float) - 0.5))
+    bs /= prior_bs * x[int(n/4 + 0.5) - 1]
+    bs += 1 / x[-1]
+
+    ks = np.log1p(-bs[:, None] * x).mean(axis=1)
+    L = n * (np.log(-(bs / ks)) - ks - 1)
+    w = 1 / np.exp(L - L[:, None]).sum(axis=1)
+
+    # remove negligible weights
+    dii = w >= 10 * np.finfo(float).eps
+    if not np.all(dii):
+        w = w[dii]
+        bs = bs[dii]
+    # normalise w
+    w /= w.sum()
+
+    # posterior mean for b
+    b = np.sum(bs * w)
+    # estimate for k
+    k = np.log1p(- b * x).mean()
+    # add prior for k
+    k = (n * k + prior_k * 0.5) / (n + prior_k)
+    sigma = - k / b
+
+    return k, sigma
+
+
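`_gpdfit` is an empirical-Bayes GPD estimator: a profile likelihood over the transformed parameter `b`, with a weak prior pulling the shape `k` toward 0.5. A hypothetical sanity check against data simulated from a known GPD via scipy, again assuming the helper is in scope:

    import numpy as np
    from scipy.stats import genpareto

    np.random.seed(0)
    true_k, true_sigma = 0.5, 1.0
    # _gpdfit expects a sorted 1D array
    x = np.sort(genpareto.rvs(c=true_k, scale=true_sigma, size=5000))

    k_hat, sigma_hat = _gpdfit(x)
    print(k_hat, sigma_hat)  # both should land near the true 0.5 and 1.0
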
+def _gpinv(p, k, sigma):
+    """Inverse Generalized Pareto distribution function"""
+    x = np.full_like(p, np.nan)
+    if sigma <= 0:
+        return x
+    ok = (p > 0) & (p < 1)
+    if np.all(ok):
+        if np.abs(k) < np.finfo(float).eps:
+            x = - np.log1p(-p)
+        else:
+            x = np.expm1(-k * np.log1p(-p)) / k
+        x *= sigma
+    else:
+        if np.abs(k) < np.finfo(float).eps:
+            x[ok] = - np.log1p(-p[ok])
+        else:
+            x[ok] = np.expm1(-k * np.log1p(-p[ok])) / k
+        x *= sigma
+        x[p == 0] = 0
+        if k >= 0:
+            x[p == 1] = np.inf
+        else:
+            x[p == 1] = - sigma / k
+
+    return x
+
+
 def bpic(trace, model=None):
     R"""Calculates Bayesian predictive information criterion n of the samples in trace from model
     Read more theory here - in a paper by some of the leading authorities on model selection -
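
On the `_gpinv` parameterization: for `0 < p < 1` it reduces to `sigma * ((1 - p)**(-k) - 1) / k`, which matches scipy's generalized Pareto quantile function with shape `c = k` and `loc=0`. A hypothetical cross-check, assuming the helper is in scope:

    import numpy as np
    from scipy.stats import genpareto

    p = np.linspace(0.01, 0.99, 9)
    k, sigma = 0.4, 2.
    assert np.allclose(_gpinv(p, k, sigma),
                       genpareto.ppf(p, c=k, scale=sigma))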
