Commit 547e2e0

ricardoV94 authored and AlexAndorra committed
Add examples for advanced uses of sample_posterior_predictive
1 parent 35ccca5 commit 547e2e0

File tree

1 file changed: +224 additions, −14 deletions

pymc/sampling/forward.py

@@ -448,35 +448,48 @@ def sample_posterior_predictive(
idata_kwargs: Optional[dict] = None,
compile_kwargs: Optional[dict] = None,
) -> Union[InferenceData, dict[str, np.ndarray]]:
"""Generate forward samples for `var_names`, conditioned on the posterior samples of variables found in the `trace`.

This method can be used to perform different kinds of model predictions, including posterior predictive checks.

The matching of unobserved model variables and posterior samples in the `trace` is based on the variable
names. Therefore, a different model than the one used for posterior sampling may be used for posterior
predictive sampling, as long as the variables whose posterior we want to condition on have the same name,
and compatible shape and coordinates.

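The name-based matching can be sketched in plain Python (an illustration only, not PyMC internals; the variable names and draw values are made up):

```python
import numpy as np

# Posterior draws from the trace, keyed by variable name (hypothetical values)
posterior = {"beta": np.array([0.9, 1.1]), "noise": np.array([0.5, 0.6])}

# Free variables of a (possibly different) model used for prediction
model_free_vars = ["beta", "noise", "extra_noise"]

# Variables found by name in the trace are conditioned on its samples;
# the rest must be forward sampled
conditioned = [v for v in model_free_vars if v in posterior]
forward_sampled = [v for v in model_free_vars if v not in posterior]
```

Shape and coordinate compatibility between the trace and the new model is checked on top of this name matching.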
Parameters
----------
trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
    Trace generated from MCMC sampling, or a list of dicts (e.g. points or from :func:`~pymc.find_MAP`),
    or :class:`xarray.Dataset` (e.g. InferenceData.posterior or InferenceData.prior)
model : Model (optional if in ``with`` context)
    Model to be used to generate the posterior predictive samples. It will
    generally be the model used to generate the `trace`, but it doesn't need to be.
var_names : Iterable[str], optional
    Names of variables for which to compute the posterior predictive samples.
    By default, only observed variables are sampled.
    See the example below for what happens when this argument is customized.
sample_dims : list of str, optional
    Dimensions over which to loop and generate posterior predictive samples.
    When ``sample_dims`` is ``None`` (default), both "chain" and "draw" are considered sample
    dimensions. Only taken into account when `trace` is InferenceData or Dataset.
random_seed : int, RandomState or Generator, optional
    Seed for the random number generator.
progressbar : bool
    Whether to display a progress bar in the command line. The bar shows the percentage
    of completion, the sampling speed in samples per second (SPS), and the estimated remaining
    time until completion ("expected time of arrival"; ETA).
return_inferencedata : bool, default True
    Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
extend_inferencedata : bool, default False
    Whether to automatically use :meth:`arviz.InferenceData.extend` to add the posterior predictive samples to
    `trace` or not. If True, `trace` is modified in place but still returned.
predictions : bool, default False
    Flag used to set the location of posterior predictive samples within the returned ``arviz.InferenceData`` object.
    If False, assumes samples are generated based on the fitting data to be used for posterior predictive checks,
    and samples are stored in the ``posterior_predictive`` group. If True, assumes samples are generated based on
    out-of-sample data as predictions, and samples are stored in the ``predictions`` group.
idata_kwargs : dict, optional
    Keyword arguments for :func:`pymc.to_inference_data` if ``predictions=False`` or to
    :func:`pymc.predictions_to_inference_data` otherwise.
@@ -489,24 +502,221 @@ def sample_posterior_predictive(
An ArviZ ``InferenceData`` object containing the posterior predictive samples (default), or
a dictionary with variable names as keys, and samples as numpy arrays.

Examples
--------
Posterior predictive checks and predictions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The most common uses of `sample_posterior_predictive` are to perform posterior predictive checks (in-sample predictions)
and to generate new model predictions (out-of-sample predictions).

.. code:: python

    import pymc as pm

    with pm.Model(coords_mutable={"trial": [0, 1, 2]}) as model:
        x = pm.MutableData("x", [-1, 0, 1], dims=["trial"])
        beta = pm.Normal("beta")
        noise = pm.HalfNormal("noise")
        y = pm.Normal("y", mu=x * beta, sigma=noise, observed=[-2, 0, 3], dims=["trial"])

        idata = pm.sample()
        # in-sample predictions
        posterior_predictive = pm.sample_posterior_predictive(idata).posterior_predictive

    with model:
        pm.set_data({"x": [-2, 2]}, coords={"trial": [3, 4]})
        # out-of-sample predictions
        predictions = pm.sample_posterior_predictive(idata, predictions=True).predictions

Using different models
^^^^^^^^^^^^^^^^^^^^^^

It's common to use the same model for posterior and posterior predictive sampling, but this is not required.
The matching between unobserved model variables and posterior samples is based on the name alone.

For the last example we could have created a new predictions model. Note that we have to specify
`var_names` explicitly, because the newly defined `y` was not given any observations:

.. code:: python

    with pm.Model(coords_mutable={"trial": [3, 4]}) as predictions_model:
        x = pm.MutableData("x", [-2, 2], dims=["trial"])
        beta = pm.Normal("beta")
        noise = pm.HalfNormal("noise")
        y = pm.Normal("y", mu=x * beta, sigma=noise, dims=["trial"])

        predictions = pm.sample_posterior_predictive(idata, var_names=["y"], predictions=True).predictions

The new model may even have a different structure and unobserved variables that don't exist in the trace.
These variables will also be forward sampled. In the following example we add a new ``extra_noise``
variable between the inferred posterior ``noise`` and the new StudentT observational distribution ``y``:

.. code:: python

    with pm.Model(coords_mutable={"trial": [3, 4]}) as distinct_predictions_model:
        x = pm.MutableData("x", [-2, 2], dims=["trial"])
        beta = pm.Normal("beta")
        noise = pm.HalfNormal("noise")
        extra_noise = pm.HalfNormal("extra_noise", sigma=noise)
        y = pm.StudentT("y", nu=4, mu=x * beta, sigma=extra_noise, dims=["trial"])

        predictions = pm.sample_posterior_predictive(idata, var_names=["y"], predictions=True).predictions

For more about out-of-model predictions, see this `blog post <https://www.pymc-labs.com/blog-posts/out-of-model-predictions-with-pymc/>`_.

The behavior of `var_names`
^^^^^^^^^^^^^^^^^^^^^^^^^^^

The function returns forward samples for any variable included in `var_names`,
conditioned on the values of other random variables found in the trace.

To ensure the samples are internally consistent, any random variable that depends
on another random variable that is being sampled is itself sampled, even if
this variable is present in the trace and was not included in `var_names`.
The final list of variables being sampled is shown in the log output.

Note that if a random variable has no dependency on other random variables,
its forward samples are equivalent to its prior samples.
Likewise, if all random variables are being sampled, the behavior of this function
is equivalent to that of :func:`~pymc.sample_prior_predictive`.

.. warning:: A random variable included in `var_names` will never be copied from the posterior group.
    It will always be sampled as described above. If you want, you can copy it manually via
    ``idata.posterior_predictive["var_name"] = idata.posterior["var_name"]``.

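The resampling rule can be sketched in plain Python over a toy dependency graph. This is an illustration of the rule as stated, not PyMC's actual implementation; the graph mirrors the example model (``x``, ``y``, ``z``, ``det``, ``obs``) used in this section:

```python
# Toy dependency graph: x, y, z, obs are random variables; det is a Deterministic
parents = {"x": [], "y": [], "z": ["x", "y"], "det": ["z"], "obs": ["det"]}
rvs = {"x", "y", "z", "obs"}
order = ["x", "y", "z", "det", "obs"]  # topological order

def vars_to_sample(var_names):
    """Sketch of which RVs end up in the 'Sampling: [...]' log."""
    # Only variables needed to compute `var_names` are considered at all
    needed, stack = set(), list(var_names)
    while stack:
        v = stack.pop()
        if v not in needed:
            needed.add(v)
            stack.extend(parents[v])
    volatile, sampled = set(), []
    for v in order:
        if v not in needed:
            continue
        depends = any(p in volatile for p in parents[v])
        # An RV is sampled if requested, or if it depends on resampled values
        if v in rvs and (v in var_names or depends):
            sampled.append(v)
            volatile.add(v)
        elif v not in rvs and depends:
            # A Deterministic recomputed from new draws propagates volatility,
            # but being requested alone does not make it volatile
            volatile.add(v)
    return sampled
```

Running this sketch on the `var_names` choices explored below reproduces the logged "Sampling:" lists shown in each example.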
The following code block explores how the behavior changes with different `var_names`:

.. code:: python

    from logging import getLogger

    import pymc as pm

    # Some environments like Google Colab suppress
    # the default logging output of PyMC
    getLogger("pymc").setLevel("INFO")

    kwargs = {"progressbar": False, "random_seed": 0}

    with pm.Model() as model:
        x = pm.Normal("x")
        y = pm.Normal("y")
        z = pm.Normal("z", x + y**2)
        det = pm.Deterministic("det", pm.math.exp(z))
        obs = pm.Normal("obs", det, 1, observed=[20])

        idata = pm.sample(tune=10, draws=10, chains=2, **kwargs)

Default behavior. Generate samples of ``obs``, conditioned on the posterior samples of ``z`` found in the trace.
These are often referred to as posterior predictive samples in the literature:

.. code:: python

    with model:
        pm.sample_posterior_predictive(idata, var_names=["obs"], **kwargs)
        # Sampling: [obs]

Re-compute the deterministic variable ``det``, conditioned on the posterior samples of ``z`` found in the trace:

.. code:: python

    with model:
        pm.sample_posterior_predictive(idata, var_names=["det"], **kwargs)
        # Sampling: []

Generate samples of ``z`` and ``det``, conditioned on the posterior samples of ``x`` and ``y`` found in the trace:

.. code:: python

    with model:
        pm.sample_posterior_predictive(idata, var_names=["z", "det"], **kwargs)
        # Sampling: [z]

Generate samples of ``y``, ``z`` and ``det``, conditioned on the posterior samples of ``x`` found in the trace.

Note: The samples of ``y`` are equivalent to its prior, since it does not depend on any other variables.
In contrast, the samples of ``z`` and ``det`` depend on the new samples of ``y`` and the posterior samples of
``x`` found in the trace.

.. code:: python

    with model:
        pm.sample_posterior_predictive(idata, var_names=["y", "z", "det"], **kwargs)
        # Sampling: [y, z]

Same as before, except ``z`` is not stored in the returned trace.
For computing ``det`` we still have to sample ``z``, as it depends on ``y``, which is also being sampled:

.. code:: python

    with model:
        pm.sample_posterior_predictive(idata, var_names=["y", "det"], **kwargs)
        # Sampling: [y, z]

Every random variable is sampled. This is equivalent to calling :func:`~pymc.sample_prior_predictive`:

.. code:: python

    with model:
        pm.sample_posterior_predictive(idata, var_names=["x", "y", "z", "det", "obs"], **kwargs)
        # Sampling: [x, y, z, obs]

Note that "sampling" a :func:`~pymc.Deterministic` does not force random variables
that depend on this quantity to be sampled too. In the following example ``z`` will not
be resampled even though it depends on ``det_xy``:

.. code:: python

    with pm.Model() as model:
        x = pm.Normal("x")
        y = pm.Normal("y")
        det_xy = pm.Deterministic("det_xy", x + y**2)
        z = pm.Normal("z", det_xy)
        det_z = pm.Deterministic("det_z", pm.math.exp(z))
        obs = pm.Normal("obs", det_z, 1, observed=[20])

        idata = pm.sample(tune=10, draws=10, chains=2, **kwargs)

        pm.sample_posterior_predictive(idata, var_names=["det_xy", "det_z"], **kwargs)
        # Sampling: []

Controlling the number of samples
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can manipulate the InferenceData to control the number of samples.

.. code:: python

    import pymc as pm

    with pm.Model() as model:
        ...
        idata = pm.sample()

Generate 1 posterior predictive sample for every 5 posterior samples:

.. code:: python

    thinned_idata = idata.sel(draw=slice(None, None, 5))
    with model:
        idata.extend(pm.sample_posterior_predictive(thinned_idata))

Generate 5 posterior predictive samples for every posterior sample:

.. code:: python

    expanded_data = idata.posterior.expand_dims(pred_id=5)
    with model:
        idata.extend(pm.sample_posterior_predictive(expanded_data))

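The array semantics behind these two manipulations can be illustrated with plain numpy (a sketch of the shapes only; the actual objects are xarray/ArviZ containers, and the sizes here are hypothetical):

```python
import numpy as np

# slice(None, None, 5) keeps every 5th draw, so 1000 draws become 200
draws = np.arange(1000)
thinned = draws[slice(None, None, 5)]

# expand_dims(pred_id=5) prepends a new dimension of size 5, so each
# (chain, draw) posterior sample is reused for 5 predictive draws
posterior = np.zeros((4, 1000))  # (chain, draw)
expanded = np.broadcast_to(posterior, (5, 4, 1000))  # (pred_id, chain, draw)
```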
"""

_trace: Union[MultiTrace, PointList]
