
Adds quadratic approximation #4847

Closed
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -20,6 +20,7 @@
- Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)).
- The `OrderedMultinomial` distribution has been added for use on ordinal data that are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)).
- Add Quadratic Approximation (see [#4847](https://github.com/pymc-devs/pymc3/pull/4847)): a very fast method that approximates the posterior with a multivariate normal. It only works well if the posterior is unimodal and roughly symmetrical.

### Maintenance
- Remove float128 dtype support (see [#4514](https://github.com/pymc-devs/pymc3/pull/4514)).
1 change: 1 addition & 0 deletions pymc3/__init__.py
@@ -76,3 +76,4 @@ def __set_compiler_flags():
from pymc3.tuning import *
from pymc3.variational import *
from pymc3.vartypes import *
from pymc3.quadratic_approximation import *
65 changes: 65 additions & 0 deletions pymc3/quadratic_approximation.py
@@ -0,0 +1,65 @@
# Copyright 2021 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for Quadratic Approximation."""

import arviz as az
import numpy as np
import scipy

__all__ = ["quadratic_approximation"]

from pymc3.tuning import find_hessian, find_MAP


def quadratic_approximation(vars, n_samples=10_000):
    """Find the quadratic approximation to the posterior, also known as the Laplace approximation.

    The approximation is a multivariate normal centered at the MAP estimate, with
    covariance equal to the inverse Hessian of the negative log posterior
    evaluated at the MAP.

    NOTE: The quadratic approximation only works well for unimodal and roughly
    symmetrical posteriors of continuous variables. The usual MCMC convergence
    and mixing statistics (e.g. R-hat, ESS) will NOT tell you how well this
    approximation fits your actual (unknown) posterior; they will always look
    excellent, because every sample is drawn from exactly the same distribution,
    the quadratic approximation itself. Use at your own risk.

    See Chapter 4 of "Bayesian Data Analysis", 3rd edition, for background.

    For compatibility with the rest of PyMC3, samples drawn from the
    approximation are returned as an arviz.InferenceData object. Note that
    these are NOT MCMC samples. The exact approximation is also returned, as a
    scipy.stats.multivariate_normal distribution.

    Parameters
    ----------
    vars: list
        List of variables to approximate the posterior for.
    n_samples: int
        How many samples to draw from the approximate posterior.

    Returns
    -------
    arviz.InferenceData:
        InferenceData with samples from the approximate posterior
    scipy.stats.multivariate_normal:
        Multivariate normal posterior approximation
    """
    map = find_MAP(vars=vars)  # posterior mode of the requested variables
Member: Are there any arguments to find_MAP that we might want to allow the user to pass, i.e. map = find_MAP(vars=vars, **map_kwargs)? Same for the Hessian. If not, that's fine (even better).

Author: Hmm, I'd say no. It's a leaky abstraction. If someone raises an issue where they need to pass args to find_MAP, we can think about how best to do it at that time. That's my 2 cents.
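For illustration only, a hypothetical sketch of the kwargs pass-through the reviewer suggests; map_kwargs and hessian_kwargs are made-up parameter names, and this variant was not adopted in the PR:

# Hypothetical variant, NOT part of this PR:
def quadratic_approximation(vars, n_samples=10_000, map_kwargs=None, hessian_kwargs=None):
    map = find_MAP(vars=vars, **(map_kwargs or {}))
    H = find_hessian(map, vars=vars, **(hessian_kwargs or {}))
    ...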

    H = find_hessian(map, vars=vars)  # Hessian of the negative log posterior at the MAP
    cov = np.linalg.inv(H)  # inverse Hessian = covariance of the Gaussian approximation
    mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars])
    posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov)
    draws = posterior.rvs(n_samples)[np.newaxis, ...]  # prepend a chain dimension for ArviZ
    # Split the flat draws back into one array per variable.
    samples = {}
    i = 0
    for v in vars:
        var_size = map[v.name].size
        samples[v.name] = draws[:, :, i : i + var_size].squeeze()
        i += var_size
    return az.convert_to_inference_data(samples), posterior
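
For context, a minimal usage sketch of the new function, mirroring the first test below; the data and model are illustrative:

import numpy as np
import pymc3 as pm

y = np.array([2642.0, 3503.0, 4358.0])

with pm.Model():
    mu = pm.Uniform("mu", -10000, 10000)
    logsigma = pm.Uniform("logsigma", -100, 100)
    pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    idata, posterior = pm.quadratic_approximation([mu, logsigma])

print(posterior.mean)  # MAP estimate, concatenated across variables
print(posterior.cov)   # inverse Hessian at the MAP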
51 changes: 51 additions & 0 deletions pymc3/tests/test_quadratic_approximation.py
@@ -0,0 +1,51 @@
import arviz as az
import numpy as np

import pymc3 as pm

from pymc3.tests.helpers import SeededTest


class TestQuadraticApproximation(SeededTest):
    def setup_method(self):
        super().setup_method()

    def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance(
        self,
    ):
        y = np.array([2642, 3503, 4358])
        n = y.size

        with pm.Model() as m:
            logsigma = pm.Uniform("logsigma", -100, 100)
            mu = pm.Uniform("mu", -10000, 10000)
            yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
            idata, posterior = pm.quadratic_approximation([mu, logsigma])

        # Analytical solution from BDA3, sec. 4.1: mode (ybar, log s) and
        # covariance diag(s^2 / n, 1 / (2 * n)).
        bda_map = [y.mean(), np.log(y.std())]
        bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])

        assert np.allclose(posterior.mean, bda_map)
        assert np.allclose(posterior.cov, bda_cov, atol=1e-4)

    def test_hdi_contains_parameters_in_linear_regression(self):
        N = 100
        M = 2
        sigma = 0.2
        X = np.random.randn(N, M)
        A = np.random.randn(M)
        noise = sigma * np.random.randn(N)
        y = X @ A + noise

        with pm.Model() as lm:
            weights = pm.Normal("weights", mu=0, sigma=1, shape=M)
            noise = pm.Exponential("noise", lam=1)
            y_observed = pm.Normal("y_observed", mu=X @ weights, sigma=noise, observed=y)

            idata, _ = pm.quadratic_approximation([weights, noise])

        # The default 94% HDI should contain the true weights and noise scale.
        hdi = az.hdi(idata)
        weight_hdi = hdi.weights.values
        assert np.all(np.logical_and(weight_hdi[0, :] < A, A < weight_hdi[1, :]))
        assert hdi.noise.values[0] < sigma < hdi.noise.values[1]
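
As a closing note, one way to sanity-check the approximation (which, per the docstring, the usual MCMC diagnostics cannot do) is to compare it against MCMC samples from the same model; a rough sketch, assuming the linear-regression model above:

with lm:
    mcmc_idata = pm.sample(1000, return_inferencedata=True)

print(az.summary(idata))       # samples from the quadratic approximation
print(az.summary(mcmc_idata))  # MCMC samples for comparison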