Skip to content

Commit 8ef4dc5

Browse files
committed
add weighted ppc
1 parent 8c81624 commit 8ef4dc5

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

pymc3/sampling.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -484,14 +484,16 @@ def _update_start_vals(a, b, model):
484484

485485
a.update({k: v for k, v in b.items() if k not in a})
486486

487+
487488
def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
488-
random_seed=None, progressbar=True):
489+
weights=None, random_seed=None, progressbar=True):
489490
"""Generate posterior predictive samples from a model given a trace.
490491
491492
Parameters
492493
----------
493494
trace : backend, list, or MultiTrace
494-
Trace generated from MCMC sampling
495+
Trace generated from MCMC sampling. If a set of weights is also passed
496+
this can be a list of traces, useful for model averaging.
495497
samples : int
496498
Number of posterior predictive samples to generate. Defaults to the
497499
length of `trace`
@@ -503,16 +505,23 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
503505
size : int
504506
The number of random draws from the distribution specified by the
505507
parameters in each sample of the trace.
508+
weights: array-like
509+
Individuals weights for each trace, useful for model averaging
510+
random_seed : int
511+
Seed for the random number generator.
512+
progressbar : bool
513+
Whether or not to display a progress bar in the command line. The
514+
bar shows the percentage of completion, the sampling speed in
515+
samples per second (SPS), and the estimated remaining time until
516+
completion ("expected time of arrival"; ETA).
506517
507518
Returns
508519
-------
509520
samples : dict
510-
Dictionary with the variables as keys. The values corresponding
511-
to the posterior predictive samples.
521+
Dictionary with the variables as keys. The values corresponding to the
522+
posterior predictive samples. If a set of weights and a matching number
523+
of traces are provided, then the samples will be weighted.
512524
"""
513-
if samples is None:
514-
samples = len(trace)
515-
516525
if model is None:
517526
model = modelcontext(model)
518527

@@ -521,10 +530,33 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
521530

522531
seed(random_seed)
523532

533+
if weights is not None:
534+
if len(trace) != len(weights):
535+
raise ValueError(
536+
'The number of traces and weights should be the same')
537+
538+
weights = np.asarray(weights)
539+
p = weights / np.sum(weights)
540+
541+
min_tr = min([len(i) for i in trace])
542+
543+
n = (min_tr * p).astype('int')
544+
# ensure n sum up to min_tr
545+
idx = np.argmax(n)
546+
n[idx] = n[idx] + min_tr - np.sum(n)
547+
548+
trace = np.concatenate([np.random.choice(trace[i], j)
549+
for i, j in enumerate(n)])
550+
551+
len_trace = len(trace)
552+
553+
if samples is None:
554+
samples = len_trace
555+
556+
indices = randint(0, len_trace, samples)
557+
524558
if progressbar:
525-
indices = tqdm(randint(0, len(trace), samples), total=samples)
526-
else:
527-
indices = randint(0, len(trace), samples)
559+
indices = tqdm(indices, total=samples)
528560

529561
ppc = defaultdict(list)
530562
for idx in indices:

0 commit comments

Comments
 (0)