---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.13.7
kernelspec:
  display_name: pymc-dev-py39
  language: python
  name: pymc-dev-py39
---

(gaussian_mixture_model)=

# Gaussian Mixture Model

:::{post} April, 2022
:tags: mixture model, classification
:category: beginner
:author: Abe Flaxman
:::

A mixture model allows us to make inferences about the component contributors to a distribution of data. More specifically, a Gaussian Mixture Model allows us to make inferences about the means and standard deviations of a specified number of underlying component Gaussian distributions.

This can be useful in a number of ways. For example, we may be interested in simply describing a complex distribution parametrically (i.e. a mixture distribution). Alternatively, we may be interested in classification, where we seek to probabilistically assign each observation to one of a number of classes.
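In symbols, a Gaussian mixture with $K$ components models each observation $x_i$ as a draw from a weighted sum of component densities

$$
p(x_i \mid w, \mu, \sigma) = \sum_{k=1}^{K} w_k \, \mathcal{N}(x_i \mid \mu_k, \sigma_k^2), \qquad \sum_{k=1}^{K} w_k = 1, \quad w_k \ge 0,
$$

where the weight $w_k$ is the probability that an observation belongs to component $k$.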

```{code-cell} ipython3
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm

from scipy.stats import norm
from xarray_einstats.stats import XrContinuousRV
```

```{code-cell} ipython3
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
```

First we generate some simulated observations.

```{code-cell} ipython3
:tags: [hide-input]

k = 3
ndata = 500
centers = np.array([-5, 0, 5])
sds = np.array([0.5, 2.0, 0.75])
idx = rng.integers(0, k, ndata)
x = rng.normal(loc=centers[idx], scale=sds[idx], size=ndata)
plt.hist(x, 40);
```

In the PyMC model, we will estimate one $\mu$ and one $\sigma$ for each of the 3 clusters. Writing a Gaussian Mixture Model is very easy with the `pm.NormalMixture` distribution. Note that we place an ordered transform (with matching initial values) on the cluster means: without it, the mixture likelihood is invariant to relabelling the clusters, and the chains can swap labels mid-sampling.

```{code-cell} ipython3
with pm.Model(coords={"cluster": range(k)}) as model:
    μ = pm.Normal(
        "μ",
        mu=0,
        sigma=5,
        transform=pm.distributions.transforms.ordered,
        initval=[-4, 0, 4],
        dims="cluster",
    )
    σ = pm.HalfNormal("σ", sigma=1, dims="cluster")
    weights = pm.Dirichlet("w", np.ones(k), dims="cluster")
    pm.NormalMixture("x", w=weights, mu=μ, sigma=σ, observed=x)

pm.model_to_graphviz(model)
```
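As an aside, `pm.NormalMixture` is a convenience wrapper around the more general `pm.Mixture`. For reference, here is a minimal sketch of the same likelihood written explicitly with `pm.Mixture` and a batched `pm.Normal.dist` (assuming the PyMC v4 API used in the rest of this notebook; this cell is illustrative and not part of the original analysis):

```python
with pm.Model(coords={"cluster": range(k)}) as explicit_model:
    μ = pm.Normal(
        "μ",
        mu=0,
        sigma=5,
        transform=pm.distributions.transforms.ordered,
        initval=[-4, 0, 4],
        dims="cluster",
    )
    σ = pm.HalfNormal("σ", sigma=1, dims="cluster")
    weights = pm.Dirichlet("w", np.ones(k), dims="cluster")
    # A single batched Normal supplies the k component distributions.
    components = pm.Normal.dist(mu=μ, sigma=σ)
    pm.Mixture("x", w=weights, comp_dists=components, observed=x)
```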
```{code-cell} ipython3
with model:
    idata = pm.sample()
```
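Before plotting, a quick numerical summary is a useful sanity check on convergence. This step is a suggestion rather than part of the original notebook, using the standard `az.summary` call:

```python
# Posterior means, credible intervals, and R-hat/ESS diagnostics
# for the cluster means, standard deviations, and weights.
az.summary(idata, var_names=["μ", "σ", "w"], round_to=2)
```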

We can also plot the trace to check the convergence of the MCMC chains, and compare the estimates to the ground-truth values.

```{code-cell} ipython3
az.plot_trace(idata, var_names=["μ", "σ"], lines=[("μ", {}, [centers]), ("σ", {}, [sds])]);
```

And if we wanted, we could calculate the probability density function and examine the estimated group membership probabilities, based on the posterior mean estimates.

```{code-cell} ipython3
xi = np.linspace(-7, 7, 500)
post = idata.posterior
pdf_components = XrContinuousRV(norm, post["μ"], post["σ"]).pdf(xi) * post["w"]
pdf = pdf_components.sum("cluster")
```

```{code-cell} ipython3
fig, ax = plt.subplots(3, 1, figsize=(7, 8), sharex=True)
# empirical histogram
ax[0].hist(x, 50)
ax[0].set(title="Data", xlabel="x", ylabel="Frequency")
# pdf
pdf_components.mean(dim=["chain", "draw"]).sum("cluster").plot.line(ax=ax[1])
ax[1].set(title="PDF", xlabel="x", ylabel="Probability\ndensity")
# plot group membership probabilities
(pdf_components / pdf).mean(dim=["chain", "draw"]).plot.line(hue="cluster", ax=ax[2])
ax[2].set(title="Group membership", xlabel="x", ylabel="Probability");
```
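If we want hard cluster assignments for the original observations rather than membership curves, one simple plug-in approach is to evaluate each component's weighted density at the data points using the posterior-mean parameters. A minimal sketch (the names `mu_hat`, `sd_hat`, `w_hat`, `resp`, and `labels` are illustrative, not from the original notebook):

```python
# Plug-in estimates: posterior means of the mixture parameters.
mu_hat = post["μ"].mean(dim=["chain", "draw"]).values  # shape (k,)
sd_hat = post["σ"].mean(dim=["chain", "draw"]).values  # shape (k,)
w_hat = post["w"].mean(dim=["chain", "draw"]).values   # shape (k,)

# Weighted density of every observation under every component: (ndata, k).
dens = w_hat * norm.pdf(x[:, None], loc=mu_hat, scale=sd_hat)
resp = dens / dens.sum(axis=1, keepdims=True)  # responsibilities
labels = resp.argmax(axis=1)                   # hard assignment per observation
```

A fully Bayesian alternative would compute the responsibilities per posterior draw and average them, which is what the membership panel above does.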

## Authors

- Authored by Abe Flaxman.
- Updated by Thomas Wiecki.
- Updated by Benjamin T. Vincent in April 2022 (#310) to use `pm.NormalMixture`.

+++

## Watermark

```{code-cell} ipython3
%load_ext watermark
%watermark -n -u -v -iv -w -p aesara,aeppl,xarray,xarray_einstats
```

:::{include} ../page_footer.md
:::