Skip to content

Commit 75d1ce0

Browse files
authored
add densityplot (#2741)
* add densityplot * add new default style and new options, add test and update release notes and example
1 parent ee7c5bc commit 75d1ce0

File tree

6 files changed

+292
-90
lines changed

6 files changed

+292
-90
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
algorithm (#2544)
1515
- Michael Osthege added support for population-samplers and implemented differential evolution metropolis (`DEMetropolis`). For models with correlated dimensions that can not use gradient-based samplers, the `DEMetropolis` sampler can give higher effective sampling rates. (also see [PR#2735](https://github.com/pymc-devs/pymc3/pull/2735))
1616
- Forestplot supports multiple traces (#2736)
17+
- Add new plot, densityplot (#2741)
1718

1819
### Fixes
1920

docs/source/api/plots.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ Plots
66

77
.. automodule:: pymc3.plots
88
:members: traceplot, plot_posterior, forestplot, compareplot, autocorrplot,
9-
energyplot, kdeplot
9+
energyplot, kdeplot, densityplot

docs/source/notebooks/model_averaging.ipynb

Lines changed: 113 additions & 87 deletions
Large diffs are not rendered by default.

pymc3/plots/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from .posteriorplot import plot_posterior, plot_posterior_predictive_glm
66
from .traceplot import traceplot
77
from .energyplot import energyplot
8+
from .densityplot import densityplot

pymc3/plots/densityplot.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import numpy as np
2+
try:
3+
import matplotlib.pyplot as plt
4+
except ImportError: # mpl is optional
5+
pass
6+
from .kdeplot import fast_kde
7+
from .utils import get_default_varnames
8+
from ..stats import hpd
9+
10+
11+
def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='mean',
                colors='cycle', outline=True, hpd_markers='', shade=0., figsize=None, textsize=12,
                plot_transformed=False, ax=None):
    """
    Generates KDE plots truncated at their 100*(1-alpha)% credible intervals from a trace or list of
    traces. KDE plots are grouped per variable and colors assigned to models.

    Parameters
    ----------
    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list
        List with names for the models in the list of traces. Useful when
        plotting more than one trace. If None and several traces are given,
        the names 'm_0', 'm_1', ... are used.
    varnames: list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    alpha : float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    point_estimate : str or None
        Plot point estimate per variable. Values should be 'mean', 'median' or None.
        Defaults to 'mean'.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively a string can be
        passed. If the string is `cycle`, it will automatically choose a color per model from
        matplotlib's cycle ('C0' to 'C9', repeating). If a single color is passed, e.g. 'k', 'C2'
        or 'red' this color will be used for all models. Defaults to 'cycle'.
    outline : boolean
        Use a line to draw the truncated KDE and its vertical limits. Defaults to True
    hpd_markers : str
        A valid `matplotlib.markers` like 'v', used to indicate the limits of the hpd interval.
        Defaults to empty string (no marker).
    shade : float
        Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
        (opaque). Defaults to 0.
    figsize : tuple
        Figure size. If None, size is (6, number of variables * 2). Ignored when `ax` is given.
    textsize : int
        Text size of the legend. Default 12.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to original variables
        Defaults to False.
    ax : axes
        Matplotlib axes, or a sequence/array of axes with at least one entry per plotted
        variable. If None (default) a new figure with one row per variable is created.

    Returns
    -------

    ax : array of Matplotlib axes, one entry per plotted variable (when `ax` is supplied, the
        flattened array of the supplied axes).

    Raises
    ------
    ValueError
        If `point_estimate` is not one of the accepted values, if the number of model names
        differs from the number of traces, or if fewer axes than variables are supplied.
    """
    if point_estimate not in ('mean', 'median', None):
        raise ValueError("Point estimate should be 'mean', 'median' or None")

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    n_traces = len(trace)

    if models is None:
        if n_traces > 1:
            models = ['m_{}'.format(i) for i in range(n_traces)]
        else:
            models = ['']
    elif len(models) != n_traces:
        raise ValueError("The number of names for the models does not match the number of models")

    # From here on len(models) == n_traces, so one color per trace is one color per model.
    if colors == 'cycle':
        colors = ['C{}'.format(i % 10) for i in range(n_traces)]
    elif isinstance(colors, str):
        colors = [colors] * n_traces

    if varnames is None:
        # Union of the default variables of every trace, keeping first-seen order.
        varnames = []
        for tr in trace:
            for v in get_default_varnames(tr.varnames, plot_transformed):
                if v not in varnames:
                    varnames.append(v)

    if ax is None:
        if figsize is None:
            figsize = (6, len(varnames) * 2)
        fig, kplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize)
        kplot = kplot.flatten()
    else:
        # Honor caller-supplied axes instead of silently creating a new figure.
        fig = None
        kplot = np.atleast_1d(ax).ravel()
        if len(kplot) < len(varnames):
            raise ValueError("The number of axes ({}) is smaller than the number of variables "
                             "to plot ({})".format(len(kplot), len(varnames)))

    for v_idx, vname in enumerate(varnames):
        for t_idx, tr in enumerate(trace):
            if vname not in tr.varnames:
                continue
            vec = tr.get_values(vname)
            # Number of scalar components in a single draw of this variable.
            k = np.size(vec[0])
            if k > 1:
                # One KDE per component: split the transposed, flattened samples into k
                # equally sized chunks and draw them all on the variable's axes.
                for component in np.split(vec.T.ravel(), k):
                    _kde_helper(component, vname, colors[t_idx], alpha, point_estimate,
                                hpd_markers, outline, shade, kplot[v_idx])
            else:
                _kde_helper(vec, vname, colors[t_idx], alpha, point_estimate,
                            hpd_markers, outline, shade, kplot[v_idx])

    if n_traces > 1:
        # Empty line plots only serve as legend handles, one per model.
        for m_idx, m in enumerate(models):
            kplot[0].plot([], label=m, c=colors[m_idx])
        kplot[0].legend(fontsize=textsize)

    if fig is not None:
        # Only rearrange the layout of a figure created here, never a caller's figure.
        fig.tight_layout()

    return kplot
121+
122+
123+
def _kde_helper(vec, vname, c, alpha, point_estimate, hpd_markers,
                outline, shade, ax):
    """
    Draw on `ax` the KDE of `vec` truncated at its HPD interval.

    Parameters
    ----------
    vec : array
        1D array from trace
    vname : str
        variable name, used as the title of `ax`
    c : str
        matplotlib color
    alpha : float
        Alpha value for (1-alpha)*100% credible intervals.
    point_estimate : str or None
        'mean' or 'median'; None draws no point estimate
    hpd_markers : str
        matplotlib marker used to indicate the limits of the hpd interval; an empty string
        draws no marker
    outline : boolean
        Draw the truncated KDE and its two vertical limits as lines
    shade : float
        Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
        (opaque).
    ax : matplotlib axes
    """
    # fast_kde is used here as returning the density values together with the lower and upper
    # bounds of the grid they were evaluated on.
    density, lower, upper = fast_kde(vec)
    x = np.linspace(lower, upper, len(density))
    hpd_ = hpd(vec, alpha)
    # Keep only the grid points that fall inside the HPD interval.
    cut = (x >= hpd_[0]) & (x <= hpd_[1])

    # NOTE(review): if no grid point lies inside the HPD interval the indexing below raises
    # IndexError -- confirm whether that can happen for real traces.
    x_cut = x[cut]
    density_cut = density[cut]
    xmin, xmax = x_cut[0], x_cut[-1]
    ymin, ymax = density_cut[0], density_cut[-1]

    if outline:
        ax.plot(x_cut, density_cut, color=c)
        # Vertical closing lines at both ends of the truncated density.
        ax.plot([xmin, xmin], [-0.5, ymin], color=c, ls='-')
        ax.plot([xmax, xmax], [-0.5, ymax], color=c, ls='-')

    if hpd_markers:
        # Use the marker requested by the caller (previously 'v' was always drawn).
        ax.plot(xmin, 0, hpd_markers, color=c, markeredgecolor='k')
        ax.plot(xmax, 0, hpd_markers, color=c, markeredgecolor='k')

    if shade:
        ax.fill_between(x, density, where=cut, color=c, alpha=shade)

    if point_estimate is not None:
        if point_estimate == 'mean':
            ps = np.mean(vec)
        elif point_estimate == 'median':
            ps = np.median(vec)
        ax.plot(ps, 0, 'o', color=c, markeredgecolor='k')

    ax.set_yticks([])
    ax.set_title(vname)
    for pos in ['left', 'right', 'top']:
        ax.spines[pos].set_visible(False)

pymc3/tests/test_plots.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .checks import close_to
77

88
from .models import multidimensional_model, simple_categorical
9-
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot
9+
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot
1010
from ..plots.utils import make_2d
1111
from ..step_methods import Slice, Metropolis
1212
from ..sampling import sample
@@ -30,7 +30,7 @@ def test_plots():
3030
plot_posterior(trace)
3131
autocorrplot(trace)
3232
energyplot(trace)
33-
33+
densityplot(trace)
3434

3535
def test_energyplot():
3636
with asmod.build_model():
@@ -64,6 +64,7 @@ def test_plots_multidimensional():
6464
traceplot(trace)
6565
plot_posterior(trace)
6666
forestplot(trace)
67+
densityplot(trace)
6768

6869
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on GPU due to njobs=2")
6970
def test_multichain_plots():

0 commit comments

Comments
 (0)