Commit a6a4158

Fix front matter.
1 parent 037d6bd commit a6a4158

File tree

2 files changed (+108 −8 lines)

examples/variational_inference/pathfinder.ipynb

+24 −8 (large diffs are not rendered by default)
@@ -0,0 +1,84 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.7
kernelspec:
  display_name: pymc4
  language: python
  name: pymc4
---

(pathfinder)=

# Pathfinder Variational Inference

+++

:::{post} Sept 30, 2022
:tags: variational inference, jax
:category: advanced
:author: Thomas Wiecki
:::

+++

Pathfinder is a variational inference algorithm that produces samples from the posterior of a Bayesian model. It compares favorably to the widely used ADVI algorithm. On large problems, it should scale better than most MCMC algorithms, including dynamic HMC (i.e., NUTS), at the cost of a more biased estimate of the posterior. For details on the algorithm, see the [arXiv preprint](https://arxiv.org/abs/2108.03782).

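In outline (following the preprint linked above), Pathfinder runs a quasi-Newton (L-BFGS) optimization from a random initialization towards the posterior mode and builds a local Gaussian approximation $q_k = \mathcal{N}(\mu_k, \Sigma_k)$ at each point $k$ along the optimization path, with the covariance taken from the L-BFGS inverse-Hessian estimate. It then returns draws from the $q_k$ with the lowest estimated KL divergence to the true posterior, i.e., the one maximizing the evidence lower bound

$$
\operatorname{ELBO}(q_k) = \mathbb{E}_{\theta \sim q_k}\left[\log p(y, \theta) - \log q_k(\theta)\right],
$$

which is estimated by Monte Carlo from a small number of draws from each $q_k$.
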
This algorithm is [implemented](https://github.com/blackjax-devs/blackjax/pull/194) in [BlackJAX](https://github.com/blackjax-devs/blackjax), a library of inference algorithms for [JAX](https://github.com/google/jax). Thanks to PyMC's JAX backend (provided by [aesara](https://github.com/aesara-devs/aesara)), we can run BlackJAX's Pathfinder on any PyMC model with some simple wrapper code.

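The essential ingredient is a JAX-compatible log-density of the model. Here is a rough sketch of that bridging step; this is an illustration rather than the actual pymc-experimental code, and it assumes PyMC v4's `pymc.sampling_jax` module with its `get_jaxified_logp` helper:

```{code-cell} ipython3
# Illustrative sketch only: compile a PyMC model's log-density to a function
# that JAX (and hence BlackJAX) can trace. Assumes pymc v4's sampling_jax
# module; the pymc-experimental wrapper may differ in detail.
import pymc as pm
from pymc.sampling_jax import get_jaxified_logp

with pm.Model() as toy_model:
    x = pm.Normal("x", 0.0, 1.0)

logp_fn = get_jaxified_logp(toy_model)  # JAX-traceable log-density
```
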
This wrapper code is implemented in [pymc-experimental](https://github.com/pymc-devs/pymc-experimental/). This tutorial shows how to run Pathfinder on your PyMC model.

You first need to install `pymc-experimental`:

`pip install git+https://github.com/pymc-devs/pymc-experimental`

```{code-cell} ipython3
import arviz as az
import numpy as np
import pymc as pm

# pymc_experimental provides the pmx.fit() wrapper used below
import pymc_experimental as pmx

print(f"Running on PyMC v{pm.__version__}")
```

First, define your PyMC model. Here, we use the 8-schools model.

```{code-cell} ipython3
# Data of the Eight Schools model
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    # Non-centered parameterization: standardized effects, rescaled below
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta
    # The likelihood uses the rescaled effects theta_1
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)
```

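The model uses the non-centered parameterization of the hierarchy: rather than sampling each school effect directly, we sample standardized effects $\tilde{\theta}_j \sim \mathcal{N}(0, 1)$ and rescale them,

$$
\theta_j = \mu + \tau \, \tilde{\theta}_j,
$$

which removes the funnel-shaped dependence between $\tau$ and the school effects that both MCMC samplers and variational approximations tend to struggle with.
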
Next, we call `pmx.fit()` and pass in the algorithm we want it to use.

```{code-cell} ipython3
with model:
    idata = pmx.fit(method="pathfinder")
```

Just like `pymc.sample()`, this returns an `InferenceData` object with samples from the posterior. Note that because these samples do not come from an MCMC chain, convergence cannot be assessed with the usual MCMC diagnostics.

```{code-cell} ipython3
az.plot_trace(idata);
```

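We can still summarize the draws numerically; we just skip MCMC-specific diagnostics such as $\hat{R}$ and effective sample size, which assume Markov chains. For instance, with ArviZ:

```{code-cell} ipython3
# Posterior means, standard deviations and HDIs only; MCMC diagnostics
# like r_hat are not meaningful for Pathfinder's non-Markovian draws
az.summary(idata, kind="stats")
```

As an optional sanity check on a small problem like this one, we can compare against NUTS, which is slower but asymptotically exact:

```{code-cell} ipython3
# Reference posterior from NUTS for comparison
with model:
    idata_nuts = pm.sample()

az.summary(idata_nuts, kind="stats")
```
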
## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,xarray
```
