From 78366d12bd9df09c2bd3967c704b3823173e7e9c Mon Sep 17 00:00:00 2001 From: rasmusbergpalm Date: Thu, 8 Jul 2021 13:38:11 +0200 Subject: [PATCH 1/6] Adds quadratic approximation The quadratic approximation is an extremely fast way to obtain an estimate of the posterior. It only works if the posterior is unimodal and approximately symmetric. The implementation finds the quadratic approximation and then samples from it to return an arviz.InferenceData object. Finding the exact (approximate) posterior and then sampling from it might seem counter intuitive, but it's done to be compatible with the rest of the codebase and Arviz functionality. The exact (approximate) posterior is also returned as a scipy.stats.multivariate_normal distribution. --- pymc3/__init__.py | 1 + pymc3/quadratic_approximation.py | 68 +++++++++++++++++++++ pymc3/tests/test_quadratic_approximation.py | 53 ++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 pymc3/quadratic_approximation.py create mode 100644 pymc3/tests/test_quadratic_approximation.py diff --git a/pymc3/__init__.py b/pymc3/__init__.py index d1a87f08cd..d4ddf3dba4 100644 --- a/pymc3/__init__.py +++ b/pymc3/__init__.py @@ -76,3 +76,4 @@ def __set_compiler_flags(): from pymc3.tuning import * from pymc3.variational import * from pymc3.vartypes import * +from pymc3.quadratic_approximation import * diff --git a/pymc3/quadratic_approximation.py b/pymc3/quadratic_approximation.py new file mode 100644 index 0000000000..5273854141 --- /dev/null +++ b/pymc3/quadratic_approximation.py @@ -0,0 +1,68 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for Quadratic Approximation.""" + +import numpy as np +import scipy +import arviz as az + +__all__ = [ + "quadratic_approximation" +] + +from pymc3.tuning import find_MAP, find_hessian + + +def quadratic_approximation(vars, n_chains=2, n_samples=10_000): + """ Finds the quadratic approximation to the posterior, also known as the laplace approximation. + + NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables. + The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation. + Use at your own risk. + + See Chapter 4 of "Bayesian Data Analysis" 3rd edition for background. + + Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples, so the notion of chains is meaningless, and is only included for downstream compatibility with Arviz. + + Also returns the exact posterior approximation as a scipy.stats.multivariate_normal distribution. + + Parameters + ---------- + vars: list + List of variables to approximate the posterior for. + n_chains: int + How many chains to simulate. + n_samples: int + How many samples to sample from the approximate posterior for each chain. + + Returns + ------- + (arviz.InferenceData, scipy.stats.multivariate_normal): + InferenceData with samples from the approximate posterior, multivariate normal posterior approximation + + """ + map = find_MAP(vars=vars) + H = find_hessian(map, vars=vars) + cov = np.linalg.inv(H) + mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars]) + posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov) + draws = posterior.rvs((n_chains, n_samples)) + samples = {} + i = 0 + for v in vars: + var_size = map[v.name].size + samples[v.name] = draws[:, :, i:i + var_size].squeeze() + i += var_size + return az.convert_to_inference_data(samples), posterior diff --git a/pymc3/tests/test_quadratic_approximation.py b/pymc3/tests/test_quadratic_approximation.py new file mode 100644 index 0000000000..14951cb21b --- /dev/null +++ b/pymc3/tests/test_quadratic_approximation.py @@ -0,0 +1,53 @@ +import numpy as np +import pymc3 as pm +import arviz as az + +from pymc3.tests.helpers import SeededTest + + +class TestQuadraticApproximation(SeededTest): + def setup_method(self): + super().setup_method() + + def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance(): + y = np.array([2642, 3503, 4358]) + n = y.size + + with pm.Model() as m: + logsigma = pm.Uniform("logsigma", -100, 100) + mu = pm.Uniform("mu", -10000, 10000) + yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) + idata, posterior = pm.quadratic_approximation([mu, logsigma]) + + # BDA 3 sec. 4.1 - analytical solution + bda_map = [y.mean(), np.log(y.std())] + bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) + + assert np.allclose(posterior.mean, bda_map) + assert np.allclose(posterior.cov, bda_cov, atol=1e-4) + + def test_hdi_contains_parameters_in_linear_regression(): + N = 100 + M = 2 + sigma = 0.2 + X = np.random.randn(N, M) + A = np.random.randn(M) + noise = sigma * np.random.randn(N) + y = X @ A + noise + + with pm.Model() as lm: + weights = pm.Normal("weights", mu=0, sigma=1, shape=M) + noise = pm.Exponential("noise", lam=1) + y_observed = pm.Normal( + "y_observed", + mu=X @ weights, + sigma=noise, + observed=y + ) + + idata, _ = pm.quadratic_approximation([weights, noise]) + + hdi = az.hdi(idata) + weight_hdi = hdi.weights.values + assert np.all(np.bitwise_and(weight_hdi[0, :] < A, A < weight_hdi[1, :])) + assert hdi.noise.values[0] < sigma < hdi.noise.values[1] From 3f9a33a5b83ea602951fd62b9e927553bb86fcb4 Mon Sep 17 00:00:00 2001 From: Rasmus Berg Palm Date: Sat, 10 Jul 2021 09:43:35 +0200 Subject: [PATCH 2/6] Spelling Co-authored-by: Thomas Wiecki --- pymc3/quadratic_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/quadratic_approximation.py b/pymc3/quadratic_approximation.py index 5273854141..591dc0b9d5 100644 --- a/pymc3/quadratic_approximation.py +++ b/pymc3/quadratic_approximation.py @@ -26,7 +26,7 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000): - """ Finds the quadratic approximation to the posterior, also known as the laplace approximation. + """ Finds the quadratic approximation to the posterior, also known as the Laplace approximation. NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables. The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation. From bee10e1bb512bde2eeb82dfd0b5c18ef1f169661 Mon Sep 17 00:00:00 2001 From: Rasmus Berg Palm Date: Sat, 10 Jul 2021 09:43:43 +0200 Subject: [PATCH 3/6] Update pymc3/quadratic_approximation.py Co-authored-by: Thomas Wiecki --- pymc3/quadratic_approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/quadratic_approximation.py b/pymc3/quadratic_approximation.py index 591dc0b9d5..1b0caed133 100644 --- a/pymc3/quadratic_approximation.py +++ b/pymc3/quadratic_approximation.py @@ -1,4 +1,4 @@ -# Copyright 2020 The PyMC Developers +# Copyright 2021 The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From f0e703006100c150dd1de15465a846cb4dcc1a9e Mon Sep 17 00:00:00 2001 From: rasmusbergpalm Date: Sat, 10 Jul 2021 09:56:47 +0200 Subject: [PATCH 4/6] PR comments --- RELEASE-NOTES.md | 1 + pymc3/quadratic_approximation.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index cd37f6e834..ca64e6ea7f 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -20,6 +20,7 @@ - Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)). - The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)). +- Adds Quadratic Approximation (see [#4847](https://github.com/pymc-devs/pymc3/pull/4847)). A very fast method to approximate the posterior with a Multivariate Normal. Only works well if the posterior is unimodal and roughly symmetrical. ### Maintenance - Remove float128 dtype support (see [#4514](https://github.com/pymc-devs/pymc3/pull/4514)). diff --git a/pymc3/quadratic_approximation.py b/pymc3/quadratic_approximation.py index 1b0caed133..d8c9a7f991 100644 --- a/pymc3/quadratic_approximation.py +++ b/pymc3/quadratic_approximation.py @@ -49,9 +49,10 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000): Returns ------- - (arviz.InferenceData, scipy.stats.multivariate_normal): - InferenceData with samples from the approximate posterior, multivariate normal posterior approximation - + arviz.InferenceData: + InferenceData with samples from the approximate posterior + scipy.stats.multivariate_normal: + Multivariate normal posterior approximation """ map = find_MAP(vars=vars) H = find_hessian(map, vars=vars) From e15c0657762c94d2130a61f5aaa5bf36a7dbc30f Mon Sep 17 00:00:00 2001 From: rasmusbergpalm Date: Sat, 10 Jul 2021 10:05:24 +0200 Subject: [PATCH 5/6] linting --- pymc3/quadratic_approximation.py | 12 +++++------- pymc3/tests/test_quadratic_approximation.py | 10 +++------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/pymc3/quadratic_approximation.py b/pymc3/quadratic_approximation.py index d8c9a7f991..89050d0389 100644 --- a/pymc3/quadratic_approximation.py +++ b/pymc3/quadratic_approximation.py @@ -14,19 +14,17 @@ """Functions for Quadratic Approximation.""" +import arviz as az import numpy as np import scipy -import arviz as az -__all__ = [ - "quadratic_approximation" -] +__all__ = ["quadratic_approximation"] -from pymc3.tuning import find_MAP, find_hessian +from pymc3.tuning import find_hessian, find_MAP def quadratic_approximation(vars, n_chains=2, n_samples=10_000): - """ Finds the quadratic approximation to the posterior, also known as the Laplace approximation. + """Finds the quadratic approximation to the posterior, also known as the Laplace approximation. NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables. The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation. @@ -64,6 +62,6 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000): i = 0 for v in vars: var_size = map[v.name].size - samples[v.name] = draws[:, :, i:i + var_size].squeeze() + samples[v.name] = draws[:, :, i : i + var_size].squeeze() i += var_size return az.convert_to_inference_data(samples), posterior diff --git a/pymc3/tests/test_quadratic_approximation.py b/pymc3/tests/test_quadratic_approximation.py index 14951cb21b..11981c1683 100644 --- a/pymc3/tests/test_quadratic_approximation.py +++ b/pymc3/tests/test_quadratic_approximation.py @@ -1,6 +1,7 @@ +import arviz as az import numpy as np + import pymc3 as pm -import arviz as az from pymc3.tests.helpers import SeededTest @@ -38,12 +39,7 @@ def test_hdi_contains_parameters_in_linear_regression(): with pm.Model() as lm: weights = pm.Normal("weights", mu=0, sigma=1, shape=M) noise = pm.Exponential("noise", lam=1) - y_observed = pm.Normal( - "y_observed", - mu=X @ weights, - sigma=noise, - observed=y - ) + y_observed = pm.Normal("y_observed", mu=X @ weights, sigma=noise, observed=y) idata, _ = pm.quadratic_approximation([weights, noise]) From c65147c7c04f9bc6aa28d7e2876224295da47961 Mon Sep 17 00:00:00 2001 From: rasmusbergpalm Date: Mon, 12 Jul 2021 08:41:47 +0200 Subject: [PATCH 6/6] Removes chains argument --- pymc3/quadratic_approximation.py | 12 +++++------- pymc3/tests/test_quadratic_approximation.py | 6 ++++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pymc3/quadratic_approximation.py b/pymc3/quadratic_approximation.py index 89050d0389..5a3c693bdf 100644 --- a/pymc3/quadratic_approximation.py +++ b/pymc3/quadratic_approximation.py @@ -23,16 +23,16 @@ from pymc3.tuning import find_hessian, find_MAP -def quadratic_approximation(vars, n_chains=2, n_samples=10_000): +def quadratic_approximation(vars, n_samples=10_000): """Finds the quadratic approximation to the posterior, also known as the Laplace approximation. NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables. - The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation. + The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all samples are from exactly the same distribution, the posterior quadratic approximation. Use at your own risk. See Chapter 4 of "Bayesian Data Analysis" 3rd edition for background. - Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples, so the notion of chains is meaningless, and is only included for downstream compatibility with Arviz. + Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples. Also returns the exact posterior approximation as a scipy.stats.multivariate_normal distribution. @@ -40,10 +40,8 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000): ---------- vars: list List of variables to approximate the posterior for. - n_chains: int - How many chains to simulate. n_samples: int - How many samples to sample from the approximate posterior for each chain. + How many samples to sample from the approximate posterior. Returns ------- @@ -57,7 +55,7 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000): cov = np.linalg.inv(H) mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars]) posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov) - draws = posterior.rvs((n_chains, n_samples)) + draws = posterior.rvs(n_samples)[np.newaxis, ...] samples = {} i = 0 for v in vars: diff --git a/pymc3/tests/test_quadratic_approximation.py b/pymc3/tests/test_quadratic_approximation.py index 11981c1683..208ac6d1da 100644 --- a/pymc3/tests/test_quadratic_approximation.py +++ b/pymc3/tests/test_quadratic_approximation.py @@ -10,7 +10,9 @@ class TestQuadraticApproximation(SeededTest): def setup_method(self): super().setup_method() - def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance(): + def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance( + self, + ): y = np.array([2642, 3503, 4358]) n = y.size @@ -27,7 +29,7 @@ def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean assert np.allclose(posterior.mean, bda_map) assert np.allclose(posterior.cov, bda_cov, atol=1e-4) - def test_hdi_contains_parameters_in_linear_regression(): + def test_hdi_contains_parameters_in_linear_regression(self): N = 100 M = 2 sigma = 0.2