(empirical-approx-overview)=

Empirical Approximation overview

For most models we use MCMC sampling algorithms such as Metropolis or NUTS. In PyMC we are used to storing traces of MCMC samples and then analyzing them. The variational inference submodule in PyMC has a similar concept: Empirical. This type of approximation stores particles for the SVGD sampler. There is no difference between independent SVGD particles and MCMC samples. Empirical acts as a bridge between MCMC sampling output and full-fledged VI utilities like apply_replacements or sample_node. For the interface description, see variational_api_quickstart. Here we focus on Empirical and give an overview of what is specific to the Empirical approximation.

:::{post} Jan 13, 2023
:tags: variational inference, approximation
:category: advanced, how-to
:author: Maxim Kochurov, Raul Maldonado, Chris Fonnesbeck
:::

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor
import seaborn as sns

from pandas import DataFrame

print(f"Running on PyMC v{pm.__version__}")
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")
np.random.seed(42)

Multimodal density

Let's recall the problem from variational_api_quickstart, where we first obtained a NUTS trace.

w = pm.floatX([0.2, 0.8])
mu = pm.floatX([-0.3, 0.5])
sd = pm.floatX([0.1, 0.1])

with pm.Model() as model:
    x = pm.NormalMixture("x", w=w, mu=mu, sigma=sd)
    trace = pm.sample(50_000, return_inferencedata=False)
with model:
    idata = pm.to_inference_data(trace)
az.plot_trace(idata);
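
As a reference point for the trace above, the same two-component mixture can be drawn from directly. This is a minimal NumPy sketch rather than anything PyMC does internally; the sample size and the seed are arbitrary choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed, for reproducibility only
w = np.array([0.2, 0.8])
mu = np.array([-0.3, 0.5])
sd = np.array([0.1, 0.1])

# Pick a component for each draw according to the weights,
# then draw from that component's normal distribution.
component = rng.choice(len(w), size=100_000, p=w)
draws = rng.normal(mu[component], sd[component])

# The mixture mean is the weighted average of the component means.
print(draws.mean(), w @ mu)
```

A histogram of `draws` should show the same two modes as the trace plot, with roughly 80% of the mass around 0.5.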

With a trace in hand, we can create an Empirical approximation.

print(pm.Empirical.__doc__)
with model:
    approx = pm.Empirical(trace)
approx

This type of approximation has its own underlying storage for samples, which is itself a pytensor.shared variable.

approx.histogram
approx.histogram.get_value()[:10]
approx.histogram.get_value().shape

It holds exactly as many samples as the trace did: 50k draws for each chain in our case. Note that if your multitrace has more than one chain, all chains are stored at once, so you get many more samples. The whole trace is flattened when creating Empirical.
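
The flattening of chains can be pictured with plain NumPy. The sketch below is an illustration only, not PyMC's code: the array `chains` and its sizes are hypothetical stand-ins for a multi-chain trace.

```python
import numpy as np

# Hypothetical multi-chain trace: 4 chains, 500 draws each, 1 dimension.
n_chains, n_draws, n_dim = 4, 500, 1
rng = np.random.default_rng(0)
chains = rng.normal(size=(n_chains, n_draws, n_dim))

# Stack all chains end to end into one (n_samples, n_dim) matrix.
flat = chains.reshape(n_chains * n_draws, n_dim)

print(flat.shape)
```

The number of stored rows is therefore the number of draws times the number of chains.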

The histogram is how the samples are stored. The structure is simple: (n_samples, n_dim). The order of the variables is stored internally in the class and in most cases is not needed by the end user.

approx.ordering
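
To make the role of an ordering concrete, here is a small NumPy sketch of packing several variables into one (n_samples, n_dim) matrix and reading one back. The variable names `a` and `b`, their shapes, and the dictionary of slices are all hypothetical; the actual object returned by `approx.ordering` may look different.

```python
import numpy as np

# Hypothetical draws for two variables, 100 samples each.
rng = np.random.default_rng(1)
draws = {"a": rng.normal(size=(100, 2)), "b": rng.normal(size=(100, 3))}

# Record which columns of the flat matrix belong to which variable.
ordering = {}
start = 0
for name, values in draws.items():
    width = values.shape[1]
    ordering[name] = slice(start, start + width)
    start += width

# Concatenate all variables side by side into one matrix.
histogram = np.concatenate(list(draws.values()), axis=1)

# Recover one variable from the flat matrix via its slice.
b_again = histogram[:, ordering["b"]]
print(histogram.shape, ordering)
```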

Sampling from the posterior is done uniformly with replacement. Call approx.sample(1000) and you will get a trace again, but the order of the samples is not preserved. There is no way to reconstruct the underlying trace from approx.sample.

new_trace = approx.sample(50000)

Once the sampling function is compiled, sampling becomes really fast.

az.plot_trace(new_trace);

As you can see, the order is gone, but the reconstructed density is the same.
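
Uniform sampling with replacement from a stored sample matrix can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the compiled function PyMC uses: `histogram` here is a stand-in array of standard normal draws, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-in for the stored samples: a (n_samples, n_dim) matrix.
histogram = rng.normal(size=(5_000, 1))

# Draw row indices uniformly at random, with replacement.
idx = rng.integers(0, histogram.shape[0], size=20_000)
resampled = histogram[idx]

# Every resampled row is one of the stored rows; only the order is lost.
print(np.isin(resampled, histogram).all())

# The empirical quantiles stay close to those of the stored samples.
q = [0.1, 0.5, 0.9]
print(np.quantile(histogram, q))
print(np.quantile(resampled, q))
```

Because the indices are random, the original sequence of draws cannot be recovered from `resampled`, while summaries of the distribution remain close.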

2d density

mu = pm.floatX([0.0, 0.0])
cov = pm.floatX([[1, 0.5], [0.5, 1.0]])
with pm.Model() as model:
    pm.MvNormal("x", mu=mu, cov=cov, shape=2)
    trace = pm.sample(1000, return_inferencedata=False)
    idata = pm.to_inference_data(trace)
with model:
    approx = pm.Empirical(trace)
az.plot_trace(approx.sample(10000));
kdeViz_df = DataFrame(
    data=approx.sample(1000).posterior["x"].squeeze(),
    columns=["First Dimension", "Second Dimension"],
)

sns.kdeplot(data=kdeViz_df, x="First Dimension", y="Second Dimension")
plt.show()

Previously we had a trace_cov function.

with model:
    print(pm.trace_cov(trace))

Now we can estimate the same covariance using Empirical.

print(approx.cov)

That's a tensor object, which we need to evaluate.

print(approx.cov.eval())

The estimates are very close and differ only by precision error. We can get the mean in the same way.

print(approx.mean.eval())
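
The mean and covariance above are sample moments of a (n_samples, n_dim) matrix. The NumPy sketch below computes them by hand for stand-in draws from the same 2d Gaussian. It is not PyMC's implementation, and which normalization a particular function uses is an assumption here; the point is only that dividing by N or by N - 1 gives estimates that differ by a factor of (N - 1) / N.

```python
import numpy as np

rng = np.random.default_rng(0)
true_cov = np.array([[1.0, 0.5], [0.5, 1.0]])

# Stand-in for stored samples: (n_samples, n_dim) draws from the target.
samples = rng.multivariate_normal(np.zeros(2), true_cov, size=10_000)
n = samples.shape[0]

mean = samples.mean(axis=0)
centered = samples - mean

cov_biased = centered.T @ centered / n          # divide by N
cov_unbiased = centered.T @ centered / (n - 1)  # divide by N - 1, the np.cov default

print(mean)
print(cov_biased)
print(np.abs(cov_unbiased - cov_biased).max())
```

With thousands of samples the gap between the two normalizations is far smaller than the Monte Carlo error of either estimate.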

Authors

Watermark

%load_ext watermark
%watermark -n -u -v -iv -w

:::{include} ../page_footer.md
:::