Update VI notebooks to v5 #497

Merged (11 commits) on Jan 15, 2023
224 changes: 144 additions & 80 deletions examples/variational_inference/GLM-hierarchical-advi-minibatch.ipynb

@@ -5,7 +5,7 @@ jupytext:
format_name: myst
format_version: 0.13
kernelspec:
-  display_name: Python 3
+  display_name: pie
language: python
name: python3
---
@@ -22,22 +22,22 @@ kernelspec:
Unlike Gaussian mixture models, (hierarchical) regression models have independent variables. These variables affect the likelihood function, but are not random variables. When using mini-batch, we should take care of that.

```{code-cell} ipython3
-%env THEANO_FLAGS=device=cpu, floatX=float32, warn_float64=ignore
+%env PYTENSOR_FLAGS=device=cpu, floatX=float32, warn_float64=ignore

import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import pymc3 as pm
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
import seaborn as sns
-import theano
-import theano.tensor as tt

from scipy import stats

-print(f"Running on PyMC3 v{pm.__version__}")
+print(f"Running on PyMC v{pm.__version__}")
```

```{code-cell} ipython3
Expand Down Expand Up @@ -67,9 +67,9 @@ coords = {"counties": data.county.unique()}
Here, `log_radon_idx_t` is a dependent variable, while `floor_idx_t` and `county_idx_t` determine the independent variable.

```{code-cell} ipython3
-log_radon_idx_t = pm.Minibatch(log_radon_idx, 100)
-floor_idx_t = pm.Minibatch(floor_idx, 100)
-county_idx_t = pm.Minibatch(county_idx, 100)
+log_radon_idx_t = pm.Minibatch(log_radon_idx, batch_size=100)
+floor_idx_t = pm.Minibatch(floor_idx, batch_size=100)
+county_idx_t = pm.Minibatch(county_idx, batch_size=100)
```
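
Whatever the API, a mini-batch must pair each response with its own predictors: row `i` of the dependent variable has to travel with row `i` of every independent variable. The sketch below illustrates this in plain NumPy, independent of PyMC. The array sizes, the random stand-in data, and the helper `draw_minibatch` are hypothetical and not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the notebook's arrays (sizes are assumptions).
n_obs = 919
log_radon = rng.normal(size=n_obs)        # dependent variable
floor = rng.integers(0, 2, size=n_obs)    # independent variable
county = rng.integers(0, 85, size=n_obs)  # independent variable


def draw_minibatch(arrays, batch_size, rng):
    """Slice every array with the same random row indices."""
    n = len(arrays[0])
    idx = rng.integers(0, n, size=batch_size)
    return idx, [a[idx] for a in arrays]


# One shared index vector keeps responses and predictors aligned.
idx, (y_b, floor_b, county_b) = draw_minibatch([log_radon, floor, county], 100, rng)
```

As far as I know, PyMC v5 documents passing several variables to a single `pm.Minibatch(..., batch_size=...)` call to get aligned slices. Whether three separate calls, as in this diff, stay aligned is worth checking against the installed version.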

```{code-cell} ipython3
Expand Down Expand Up @@ -125,8 +125,9 @@ Then, run ADVI with mini-batch.

```{code-cell} ipython3
with hierarchical_model:
-    approx = pm.fit(100000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
-    idata_advi = az.from_pymc3(approx.sample(500))
+    approx = pm.fit(100_000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
+
+idata_advi = approx.sample(500)
```
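
The `CheckParametersConvergence(tolerance=1e-4)` callback stops the fit once the variational parameters change very little between checks. The following is a simplified sketch of such a relative-change criterion, not PyMC's implementation. The helper names and the synthetic parameter snapshots are hypothetical.

```python
import numpy as np


def relative_change(prev, curr, eps=1e-10):
    """Largest elementwise relative change between two parameter vectors."""
    prev = np.asarray(prev, dtype=float)
    curr = np.asarray(curr, dtype=float)
    return np.max(np.abs(curr - prev) / (np.abs(prev) + eps))


def first_converged_check(snapshots, tolerance=1e-4):
    """Index of the first snapshot whose change from the previous one is below tolerance."""
    for i in range(1, len(snapshots)):
        if relative_change(snapshots[i - 1], snapshots[i]) < tolerance:
            return i
    return None


# Hypothetical parameter snapshots approaching a fixed point geometrically.
target = np.array([1.0, -2.0, 0.5])
snapshots = [target * (1.0 + 0.5**k) for k in range(1, 30)]
stop_at = first_converged_check(snapshots, tolerance=1e-4)
```

A tighter tolerance makes the fit run longer before stopping. If the criterion is never met, the fit simply uses the full iteration budget passed to `pm.fit`.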

Check the trace of ELBO and compare the result with MCMC.
@@ -135,6 +136,20 @@ Check the trace of ELBO and compare the result with MCMC.
plt.plot(approx.hist);
```

+We can extract the covariance matrix from the mean field approximation and use it as the scaling matrix for the NUTS algorithm.
+
+```{code-cell} ipython3
+scaling = approx.cov.eval()
+```
+
+Also, we can generate samples (one for each chain) to use as the starting points for the sampler.
+
+```{code-cell} ipython3
+n_chains = 4
+sample = approx.sample(return_inferencedata=False, size=n_chains)
+start_dict = list(sample[i] for i in range(n_chains))
+```
+
```{code-cell} ipython3
# Inference button (TM)!
with pm.Model(coords=coords):
@@ -155,15 +170,15 @@ with pm.Model(coords=coords):
    radon_like = pm.Normal("radon_like", mu=radon_est, sigma=eps, observed=log_radon_idx)

    # essentially, this is what init='advi' does
-    step = pm.NUTS(scaling=approx.cov.eval(), is_cov=True)
-    hierarchical_trace = pm.sample(
-        2000, step, start=approx.sample()[0], progressbar=True, return_inferencedata=True
-    )
+    step = pm.NUTS(scaling=scaling, is_cov=True)
+    hierarchical_trace = pm.sample(draws=2000, step=step, chains=n_chains, initvals=start_dict)
```
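
The cell above reuses two things from the ADVI fit: its covariance as the NUTS scaling matrix, and one draw per chain as a starting point. A mean-field Gaussian treats the parameters as independent, so that covariance is diagonal. The NumPy sketch below illustrates both ideas with made-up numbers. The values of `mu` and `sigma` are assumptions, not results from the notebook.

```python
import numpy as np

# Hypothetical mean-field posterior: per-parameter means and standard deviations.
mu = np.array([0.0, 1.0, -1.0])
sigma = np.array([0.5, 1.2, 0.05])

# Independent components imply a diagonal covariance matrix.
cov = np.diag(sigma**2)

# One starting point per chain, drawn from the approximation.
rng = np.random.default_rng(1)
n_chains = 4
starts = mu + sigma * rng.standard_normal((n_chains, mu.size))
```

Because the matrix is diagonal, it rescales each parameter individually but carries no information about posterior correlations.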

```{code-cell} ipython3
az.plot_density(
-    [idata_advi, hierarchical_trace], var_names=["~alpha", "~beta"], data_labels=["ADVI", "NUTS"]
+    data=[idata_advi, hierarchical_trace],
+    var_names=["~alpha", "~beta"],
+    data_labels=["ADVI", "NUTS"],
);
```

@@ -173,3 +188,6 @@ az.plot_density(
%load_ext watermark
%watermark -n -u -v -iv -w -p xarray
```

+:::{include} ../page_footer.md
+:::