@@ -484,14 +484,16 @@ def _update_start_vals(a, b, model):
484
484
485
485
a .update ({k : v for k , v in b .items () if k not in a })
486
486
487
+
487
488
def sample_ppc (trace , samples = None , model = None , vars = None , size = None ,
488
- random_seed = None , progressbar = True ):
489
+ weights = None , random_seed = None , progressbar = True ):
489
490
"""Generate posterior predictive samples from a model given a trace.
490
491
491
492
Parameters
492
493
----------
493
494
trace : backend, list, or MultiTrace
494
- Trace generated from MCMC sampling
495
+ Trace generated from MCMC sampling. If a set of weights is also passed,
496
+ this can be a list of traces, useful for model averaging.
495
497
samples : int
496
498
Number of posterior predictive samples to generate. Defaults to the
497
499
length of `trace`
@@ -503,16 +505,23 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
503
505
size : int
504
506
The number of random draws from the distribution specified by the
505
507
parameters in each sample of the trace.
508
+ weights: array-like
509
+ Individual weights for each trace, useful for model averaging
510
+ random_seed : int
511
+ Seed for the random number generator.
512
+ progressbar : bool
513
+ Whether or not to display a progress bar in the command line. The
514
+ bar shows the percentage of completion, the sampling speed in
515
+ samples per second (SPS), and the estimated remaining time until
516
+ completion ("expected time of arrival"; ETA).
506
517
507
518
Returns
508
519
-------
509
520
samples : dict
510
- Dictionary with the variables as keys. The values corresponding
511
- to the posterior predictive samples.
521
+ Dictionary with the variables as keys. The values correspond to the
522
+ posterior predictive samples. If a set of weights and a matching number
523
+ of traces are provided, then the samples will be weighted.
512
524
"""
513
- if samples is None :
514
- samples = len (trace )
515
-
516
525
if model is None :
517
526
model = modelcontext (model )
518
527
@@ -521,10 +530,33 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
521
530
522
531
seed (random_seed )
523
532
533
+ if weights is not None :
534
+ if len (trace ) != len (weights ):
535
+ raise ValueError (
536
+ 'The number of traces and weights should be the same' )
537
+
538
+ weights = np .asarray (weights )
539
+ p = weights / np .sum (weights )
540
+
541
+ min_tr = min ([len (i ) for i in trace ])
542
+
543
+ n = (min_tr * p ).astype ('int' )
544
+ # ensure n sum up to min_tr
545
+ idx = np .argmax (n )
546
+ n [idx ] = n [idx ] + min_tr - np .sum (n )
547
+
548
+ trace = np .concatenate ([np .random .choice (trace [i ], j )
549
+ for i , j in enumerate (n )])
550
+
551
+ len_trace = len (trace )
552
+
553
+ if samples is None :
554
+ samples = len_trace
555
+
556
+ indices = randint (0 , len_trace , samples )
557
+
524
558
if progressbar :
525
- indices = tqdm (randint (0 , len (trace ), samples ), total = samples )
526
- else :
527
- indices = randint (0 , len (trace ), samples )
559
+ indices = tqdm (indices , total = samples )
528
560
529
561
ppc = defaultdict (list )
530
562
for idx in indices :
0 commit comments