---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.7
kernelspec:
  display_name: pymc4
  language: python
  name: pymc4
---

(pathfinder)=

# Pathfinder Variational Inference

:::{post} Sept 30, 2022
:tags: variational inference, jax
:category: advanced, how-to
:author: Thomas Wiecki
:::

+++

Pathfinder {cite:p}`zhang2021pathfinder` is a variational inference algorithm that produces samples from the posterior of a Bayesian model. It compares favorably to the widely used ADVI algorithm. On large problems, it should scale better than most MCMC algorithms, including dynamic HMC (i.e. NUTS), at the cost of a more biased estimate of the posterior. For details on the algorithm, see the [arXiv preprint](https://arxiv.org/abs/2108.03782).
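
In brief (a sketch of the algorithm as described in the paper, not part of the original notebook): Pathfinder runs L-BFGS on the log posterior and, at each point along the optimization path, builds a local Gaussian approximation $q_k$ from the L-BFGS inverse-Hessian estimate. It then keeps the approximation that maximizes a Monte Carlo estimate of the evidence lower bound,

$$
k^* = \underset{k}{\arg\max}\; \mathrm{ELBO}(q_k), \qquad
\mathrm{ELBO}(q) = \mathbb{E}_{q}\left[\log p(\theta, y) - \log q(\theta)\right],
$$

and draws its samples from $q_{k^*}$.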

This algorithm is [implemented](https://github.com/blackjax-devs/blackjax/pull/194) in [BlackJAX](https://github.com/blackjax-devs/blackjax), a library of inference algorithms for [JAX](https://github.com/google/jax). Through PyMC's JAX backend (powered by [aesara](https://github.com/aesara-devs/aesara)) we can run BlackJAX's Pathfinder on any PyMC model with some simple wrapper code.

This wrapper code is implemented in [pymcx](https://github.com/pymc-devs/pymcx/). This tutorial shows how to run Pathfinder on your PyMC model.

You first need to install `pymcx`:

`pip install git+https://github.com/pymc-devs/pymcx`

```{code-cell} ipython3
import arviz as az
import numpy as np
import pymc as pm
import pymcx as pmx

print(f"Running on PyMC v{pm.__version__}")
```

First, define your PyMC model. Here, we use the 8-schools model.

```{code-cell} ipython3
# Data of the Eight Schools Model
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    # Non-centered parameterization: theta is a standardized offset,
    # theta_1 = mu + tau * theta is the actual school effect.
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)
```
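
A quick aside (not in the original text): the model above uses the non-centered parameterization. Rather than sampling each school effect $\theta_j$ directly from $\mathrm{N}(\mu, \tau)$, we sample a standardized offset $\tilde\theta_j$ and rescale it,

$$
\tilde\theta_j \sim \mathrm{N}(0, 1), \qquad \theta_j = \mu + \tau\,\tilde\theta_j,
$$

which is distributionally equivalent but typically gives hierarchical models like this one a much better-behaved posterior geometry.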

Next, we call `pmx.fit()` and pass in the algorithm we want it to use.

```{code-cell} ipython3
with model:
    idata = pmx.fit(method="pathfinder")
```

Just like `pymc.sample()`, this returns an `InferenceData` object with samples from the posterior. Note that because these samples do not come from an MCMC chain, convergence cannot be assessed with the usual MCMC diagnostics, such as $\hat{R}$.

```{code-cell} ipython3
az.plot_trace(idata);
```
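
For a numeric view of the posterior we can use `az.summary` as usual. Since the draws are not MCMC chains, one reasonable call (an illustration, not part of the original notebook) asks ArviZ for the posterior statistics only and skips the sampler diagnostics:

```{code-cell} ipython3
# kind="stats" returns means, standard deviations and HDIs, omitting
# the MCMC diagnostics (ess, r_hat) that don't apply to these draws.
az.summary(idata, kind="stats")
```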

## References

:::{bibliography}
:filter: docname in docnames
:::

+++

## Authors

* Authored by Thomas Wiecki on Oct 11 2022 ([pymc-examples#429](https://github.com/pymc-devs/pymc-examples/pull/429))

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,xarray
```

:::{include} ../page_footer.md
:::