Skip to content

Commit 3486b36

Browse files
add new scatterplot function
1 parent 371bb28 commit 3486b36

File tree

7 files changed

+700
-597
lines changed

7 files changed

+700
-597
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
`p=tt.nnet.sigmoid(logit_p)`.
1010
- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method
1111
which in turn makes sampling from a `DensityDist` possible.
12+
- Add new `scatterplot` function for plotting scatter or hexbin matrices of sampled parameters.
13+
Optionally it can plot divergences.
1214

1315
### Fixes
1416

docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb

Lines changed: 518 additions & 569 deletions
Large diffs are not rendered by default.

pymc3/plots/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .traceplot import traceplot
77
from .energyplot import energyplot
88
from .densityplot import densityplot
9+
from .scatterplot import scatterplot

pymc3/plots/artists.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy.stats import mode
3+
from collections import OrderedDict
34

45
from pymc3.stats import hpd
56
from .kdeplot import fast_kde, kdeplot
@@ -144,3 +145,26 @@ def set_key_if_doesnt_exist(d, key, value):
144145
display_ref_val(ref_val)
145146
if rope is not None:
146147
display_rope(rope)
148+
149+
def scale_text(figsize, text_size):
    """Pick a font size for a figure of the given size.

    Parameters
    ----------
    figsize : sequence or None
        Figure size; only its first element is inspected.
    text_size : number or None
        Explicitly requested text size.

    Returns
    -------
    ``text_size`` unchanged when it is not None or when ``figsize`` is None.
    Otherwise 12 if ``figsize[0] <= 11``, else ``figsize[0]``.
    """
    # An explicit size always wins; without a figsize there is nothing to scale by.
    if text_size is not None or figsize is None:
        return text_size

    first_dim = figsize[0]
    return 12 if first_dim <= 11 else first_dim
159+
160+
def get_trace_dict(tr, varnames):
    """Collect the values of ``varnames`` into an ordered mapping of 1-d series.

    Parameters
    ----------
    tr : object exposing ``get_values(name, combine=True, squeeze=True)``
        Source of the sampled values.
    varnames : iterable of str
        Names to extract, in the order they should appear in the result.

    Returns
    -------
    OrderedDict
        Values with at most one dimension are stored under the variable name.
        Values with more dimensions are reshaped to ``(shape[0], -1)`` and each
        resulting column is stored under ``'<name>_<column index>'``.
    """
    flattened = OrderedDict()
    for name in varnames:
        samples = tr.get_values(name, combine=True, squeeze=True)
        if samples.ndim <= 1:
            flattened[name] = samples
            continue
        # Keep the first axis (presumably the draws -- confirm against
        # get_values) and turn every remaining component into its own row.
        components = samples.reshape(samples.shape[0], -1).T
        flattened.update(
            ('_'.join((name, str(idx))), component)
            for idx, component in enumerate(components)
        )
    return flattened

pymc3/plots/posteriorplot.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
pass
77
import numpy as np
88

9-
from .artists import plot_posterior_op
9+
from .artists import plot_posterior_op, get_trace_dict, scale_text
10+
1011
from .utils import identity_transform, get_default_varnames
1112

1213

@@ -59,17 +60,6 @@ def plot_posterior(trace, varnames=None, transform=identity_transform, figsize=N
5960
6061
"""
6162

62-
def scale_text(figsize, text_size=text_size):
63-
"""Scale text to figsize."""
64-
65-
if text_size is None and figsize is not None:
66-
if figsize[0] <= 11:
67-
return 12
68-
else:
69-
return figsize[0]
70-
else:
71-
return text_size
72-
7363
def create_axes_grid(figsize, traces):
7464
l_trace = len(traces)
7565
if l_trace == 1:
@@ -85,18 +75,6 @@ def create_axes_grid(figsize, traces):
8575
ax = ax[:-1]
8676
return fig, ax
8777

88-
def get_trace_dict(tr, varnames):
89-
traces = OrderedDict()
90-
for v in varnames:
91-
vals = tr.get_values(v, combine=True, squeeze=True)
92-
if vals.ndim > 1:
93-
vals_flat = vals.reshape(vals.shape[0], -1).T
94-
for i, vi in enumerate(vals_flat):
95-
traces['_'.join([v, str(i)])] = vi
96-
else:
97-
traces[v] = vals
98-
return traces
99-
10078
if isinstance(trace, np.ndarray):
10179
if figsize is None:
10280
figsize = (6, 2)
@@ -106,7 +84,7 @@ def get_trace_dict(tr, varnames):
10684
plot_posterior_op(transform(trace), ax=ax, kde_plot=kde_plot,
10785
point_estimate=point_estimate, round_to=round_to,
10886
alpha_level=alpha_level, ref_val=ref_val, rope=rope,
109-
text_size=scale_text(figsize), **kwargs)
87+
text_size=scale_text(figsize, text_size), **kwargs)
11088
else:
11189
if varnames is None:
11290
varnames = get_default_varnames(trace.varnames, plot_transformed)
@@ -132,8 +110,8 @@ def get_trace_dict(tr, varnames):
132110
plot_posterior_op(tr_values, ax=a, kde_plot=kde_plot,
133111
point_estimate=point_estimate, round_to=round_to,
134112
alpha_level=alpha_level, ref_val=ref_val[idx],
135-
rope=rope[idx], text_size=scale_text(figsize), **kwargs)
136-
a.set_title(v, fontsize=scale_text(figsize))
113+
rope=rope[idx], text_size=scale_text(figsize, text_size), **kwargs)
114+
a.set_title(v, fontsize=scale_text(figsize, text_size))
137115

138116
plt.tight_layout()
139117
return ax

pymc3/plots/scatterplot.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import warnings
2+
3+
try:
4+
import matplotlib.pyplot as plt
5+
except ImportError: # mpl is optional
6+
pass
7+
import numpy as np
8+
from .utils import get_default_varnames
9+
from .artists import get_trace_dict, scale_text
10+
11+
12+
def _draw_pair(axis, x_vals, y_vals, hexbin, divergent, kwargs, kwargs_divergence):
    """Draw one scatter/hexbin panel and overlay the divergent samples, if any."""
    if hexbin:
        axis.hexbin(x_vals, y_vals, mincnt=1, **kwargs)
    else:
        axis.scatter(x_vals, y_vals, **kwargs)

    if divergent is not None:
        axis.scatter(x_vals[divergent], y_vals[divergent], **kwargs_divergence)


def scatterplot(trace, varnames=None, figsize=None, text_size=None,
                ax=None, hexbin=False, plot_transformed=False, divergences=False,
                kwargs_divergence=None, sub_varnames=None, **kwargs):
    """
    Plot a scatter or hexbin matrix of the sampled parameters.

    Parameters
    ----------

    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variables are plotted
    figsize : figure size tuple
        If None, size is (8 + numvars, 8 + numvars)
    text_size: int
        Text size for labels. If None it is derived from figsize.
    ax : axes
        Matplotlib axes. A single axes is only used when exactly two
        variables are plotted; otherwise it is indexed as ``ax[row, col]``
        with ``numvars`` rows and columns.
    hexbin : Boolean
        If True draws an hexbin plot
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False). Applies when varnames/sub_varnames = None.
        When a list of varnames/sub_varnames is passed, transformed variables can be passed
        using their names.
    divergences : Boolean
        If True divergences will be plotted in a different color. They are
        read from ``trace['diverging']``; if that key is missing a warning is
        issued and the plot is drawn without them.
    kwargs_divergence : dicts, optional
        Additional keywords passed to ax.scatter for divergences
    sub_varnames : list, optional
        Additional varnames passed for plotting subsets of multidimensional variables.
        Only used when varnames is None.
    **kwargs
        Passed to ax.hexbin or ax.scatter for the main panels.

    Returns
    -------

    ax : matplotlib axes

    Raises
    ------
    ValueError
        If fewer than two variables are selected for plotting.

    """
    if varnames is None:
        default_names = get_default_varnames(trace.varnames, bool(plot_transformed))
        trace_dict = get_trace_dict(trace, default_names)
        varnames = list(trace_dict) if sub_varnames is None else sub_varnames
    else:
        trace_dict = get_trace_dict(trace, varnames)
        varnames = list(trace_dict)

    numvars = len(varnames)
    if numvars < 2:
        raise ValueError('Number of variables to be plotted must be 2 or greater.')

    # Resolve the default figsize *before* scaling the text, otherwise a call
    # with neither figsize nor text_size never gets a scaled text size.
    if figsize is None:
        figsize = (8 + numvars, 8 + numvars)
    text_size = scale_text(figsize, text_size)

    if kwargs_divergence is None:
        kwargs_divergence = {}

    # Look the divergence mask up once instead of once per panel, so the
    # warning/message is emitted a single time and a missing key does not
    # abort the plot half-way through the matrix.
    divergent = None
    if divergences:
        try:
            diverging = trace['diverging']
        except KeyError:
            warnings.warn('There is no divergence information in the passed trace.')
        else:
            if np.any(diverging):
                divergent = diverging == 1
            else:
                print('No divergences were found.')

    if numvars == 2 and ax is not None:
        _draw_pair(ax, trace_dict[varnames[0]], trace_dict[varnames[1]],
                   hexbin, divergent, kwargs, kwargs_divergence)
        return ax

    if ax is None:
        _, ax = plt.subplots(nrows=numvars, ncols=numvars, figsize=figsize)

    last_row = numvars - 1
    for i in range(numvars):
        var1 = trace_dict[varnames[i]]
        for j in range(i, numvars):
            # Only the lower triangle is drawn: drop the diagonal axes and
            # the upper-triangle mirror of every plotted panel.
            ax[i, j].axes.remove()
            if i == j:
                continue

            var2 = trace_dict[varnames[j]]
            panel = ax[j, i]  # column i holds varnames[i] on x, row j holds varnames[j] on y
            _draw_pair(panel, var1, var2, hexbin, divergent, kwargs, kwargs_divergence)

            if j != last_row:
                panel.set_xticks([])
            else:
                # Bottom row: label the column with the variable actually on its x axis.
                panel.set_xlabel('{}'.format(varnames[i]), fontsize=text_size)
                panel.tick_params(labelsize=text_size)

            if i != 0:
                panel.set_yticks([])
            else:
                # Left column: label the row with the variable on its y axis.
                panel.set_ylabel('{}'.format(varnames[j]), fontsize=text_size)
                panel.tick_params(labelsize=text_size)

    plt.tight_layout()
    return ax

pymc3/tests/test_plots.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .checks import close_to
77

88
from .models import multidimensional_model, simple_categorical
9-
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot
9+
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot, scatterplot
1010
from ..plots.utils import make_2d
1111
from ..step_methods import Slice, Metropolis
1212
from ..sampling import sample
@@ -66,6 +66,7 @@ def test_plots_multidimensional():
6666
forestplot(trace)
6767
densityplot(trace)
6868

69+
6970
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on GPU due to cores=2")
7071
def test_multichain_plots():
7172
model = build_disaster_model()
@@ -118,3 +119,15 @@ def test_plots_transformed():
118119
assert autocorrplot(trace, plot_transformed=True).shape == (2, 2)
119120
assert plot_posterior(trace).numCols == 1
120121
assert plot_posterior(trace, plot_transformed=True).shape == (2, )
122+
123+
def test_scatterplot():
    """Smoke-test scatterplot with default, hexbin/transformed and sub_varnames options."""
    with pm.Model():
        loc = pm.Normal('a', shape=2)
        scale = pm.HalfNormal('c', shape=2)
        pm.Normal('b', loc, scale, shape=2)
        pm.Normal('d', 100, 1)
        trace = pm.sample(1000)

    option_sets = (
        {},
        {'hexbin': True, 'plot_transformed': True},
        {'sub_varnames': ['a_0', 'c_0', 'b_1']},
    )
    for options in option_sets:
        scatterplot(trace, **options)

0 commit comments

Comments
 (0)