Adds quadratic approximation #4847

1 change: 1 addition & 0 deletions pymc3/__init__.py
@@ -76,3 +76,4 @@ def __set_compiler_flags():
from pymc3.tuning import *
from pymc3.variational import *
from pymc3.vartypes import *
from pymc3.quadratic_approximation import *
68 changes: 68 additions & 0 deletions pymc3/quadratic_approximation.py
@@ -0,0 +1,68 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for Quadratic Approximation."""

import numpy as np
import scipy.stats
import arviz as az

__all__ = [
    "quadratic_approximation",
]

from pymc3.tuning import find_MAP, find_hessian


def quadratic_approximation(vars, n_chains=2, n_samples=10_000):
""" Finds the quadratic approximation to the posterior, also known as the laplace approximation.

NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables.
The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation.
Use at your own risk.

See Chapter 4 of "Bayesian Data Analysis" 3rd edition for background.

Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples, so the notion of chains is meaningless, and is only included for downstream compatibility with Arviz.

Also returns the exact posterior approximation as a scipy.stats.multivariate_normal distribution.

Parameters
----------
vars: list
List of variables to approximate the posterior for.
n_chains: int

Member:
I don't think we need this kwarg and should fix it to 1.

Author:
Unfortunately, Arviz will complain if there aren't at least 2 chains.

Member:
Can you share an example? I think you'll get into trouble with multivariate distributions, where ArviZ will interpret the leading dimension as the chains, but I would solve that as follows:

data = {'a': np.ones(10), 'b': np.ones((100, 3))}

az.convert_to_inference_data(data) # warning, misinterprets `b` as having 100 chains

az.convert_to_inference_data({k: v[np.newaxis, ...] for k, v in data.items()}) # Good, explicitly sets number of chains to 1.

But maybe I'm misunderstanding the issue!

I'm against including chains since it gives the impression that you might want to run diagnostics on the quality of the returned samples.

Author:
[screenshot: the ArviZ warning and the resulting plots]

I guess that's not too bad. All the plots seem to work; I just saw that warning and figured nothing would. I'll remove the chains kwarg. I fully agree: the less we suggest that running diagnostics on these samples is meaningful, the better. A little warning and a NaN might actually be good :)

        How many chains to simulate.
    n_samples: int
        How many samples to draw from the approximate posterior for each chain.

    Returns
    -------
    (arviz.InferenceData, scipy.stats.multivariate_normal):

Member:
Also list these items individually; you can look at other docstrings for examples of documenting multiple return elements.

Author:
I couldn't find any examples. I tried a new format. Please provide a reference if it's not correct :)

        InferenceData with samples from the approximate posterior, and the multivariate normal posterior approximation itself.

    """
    # The mode of the approximation is the MAP estimate.
    map = find_MAP(vars=vars)

Member:
Are there any arguments to find_MAP that we might want to allow the user to pass? I.e. map = find_MAP(vars=vars, **map_kwargs)? Same for the Hessian. If not, that's fine (even better)

Author:
Hmm. I think I'd say no. It's a leaky abstraction. If someone raises an issue where they need to pass args to find_MAP we can think about how to best do it at that time. That's my 2 cents.
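
For concreteness, a minimal sketch of the pass-through being discussed, not adopted in this PR; the `map_kwargs`/`hessian_kwargs` parameter names are hypothetical:

```python
# Hypothetical variant floated in the review above, NOT part of this PR:
# forward optional keyword arguments to find_MAP and find_hessian.
from pymc3.tuning import find_MAP, find_hessian


def quadratic_approximation_with_kwargs(vars, n_samples=10_000, map_kwargs=None, hessian_kwargs=None):
    map_kwargs = {} if map_kwargs is None else map_kwargs
    hessian_kwargs = {} if hessian_kwargs is None else hessian_kwargs
    map = find_MAP(vars=vars, **map_kwargs)  # e.g. map_kwargs={"method": "L-BFGS-B"}
    H = find_hessian(map, vars=vars, **hessian_kwargs)
    ...  # remainder identical to the PR implementation that follows below
```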

    # The covariance of the approximation is the inverse Hessian at the MAP estimate.
    H = find_hessian(map, vars=vars)
    cov = np.linalg.inv(H)
    mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars])
    posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov)
    # For a multi-dimensional posterior, draws has shape (n_chains, n_samples, ndim).
    draws = posterior.rvs((n_chains, n_samples))
    # Split the flat draws back into one entry per variable.
    samples = {}
    i = 0
    for v in vars:
        var_size = map[v.name].size
        samples[v.name] = draws[:, :, i:i + var_size].squeeze()
        i += var_size
    return az.convert_to_inference_data(samples), posterior
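
For orientation, a minimal usage sketch of the function this PR adds; the model and data below are made up for illustration and are not part of the PR:

```python
import arviz as az
import numpy as np
import pymc3 as pm

y = np.random.normal(loc=1.0, scale=0.5, size=200)

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.Exponential("sigma", lam=1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

    # Like find_MAP, this must be called inside the model context.
    idata, posterior = pm.quadratic_approximation([mu, sigma])

print(posterior.mean, posterior.cov)  # the fitted scipy.stats.multivariate_normal
az.plot_posterior(idata)  # works like any InferenceData, but remember these are not MCMC samples
```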
53 changes: 53 additions & 0 deletions pymc3/tests/test_quadratic_approximation.py
@@ -0,0 +1,53 @@
import numpy as np
import pymc3 as pm
import arviz as az

from pymc3.tests.helpers import SeededTest


class TestQuadraticApproximation(SeededTest):
    def setup_method(self):
        super().setup_method()

    def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance(self):
        y = np.array([2642, 3503, 4358])
        n = y.size

        with pm.Model() as m:
            logsigma = pm.Uniform("logsigma", -100, 100)
            mu = pm.Uniform("mu", -10000, 10000)
            yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
            idata, posterior = pm.quadratic_approximation([mu, logsigma])

        # BDA 3 sec. 4.1 - analytical solution
        bda_map = [y.mean(), np.log(y.std())]
        bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])

        assert np.allclose(posterior.mean, bda_map)
        assert np.allclose(posterior.cov, bda_cov, atol=1e-4)

    def test_hdi_contains_parameters_in_linear_regression(self):
        N = 100
        M = 2
        sigma = 0.2
        X = np.random.randn(N, M)
        A = np.random.randn(M)
        noise = sigma * np.random.randn(N)
        y = X @ A + noise

        with pm.Model() as lm:
            weights = pm.Normal("weights", mu=0, sigma=1, shape=M)
            noise = pm.Exponential("noise", lam=1)
            y_observed = pm.Normal(
                "y_observed",
                mu=X @ weights,
                sigma=noise,
                observed=y,
            )

            idata, _ = pm.quadratic_approximation([weights, noise])

        hdi = az.hdi(idata)
        weight_hdi = hdi.weights.values
        assert np.all(np.bitwise_and(weight_hdi[0, :] < A, A < weight_hdi[1, :]))
        assert hdi.noise.values[0] < sigma < hdi.noise.values[1]
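
For reference, the expected values asserted in the first test come from the standard normal-model result sketched in BDA3, Sec. 4.1: with (approximately) flat priors on mu and log sigma, the log posterior and its curvature at the mode give

```latex
\log p(\mu, \log\sigma \mid y) = \mathrm{const} - n\log\sigma - \frac{1}{2\sigma^{2}}\sum_{i=1}^{n}(y_i - \mu)^{2},
\qquad
\hat\mu = \bar{y},\quad
\hat\sigma^{2} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \bar{y})^{2},\quad
\mathrm{Cov} \approx \begin{pmatrix} \hat\sigma^{2}/n & 0 \\ 0 & 1/(2n) \end{pmatrix}
```

which matches `bda_map = [y.mean(), np.log(y.std())]` and `bda_cov = [[y.var() / n, 0], [0, 1 / (2 * n)]]` above.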