Skip to content

Commit be39d3d

Browse files
followed sugestions made in #2861
1 parent 3486b36 commit be39d3d

File tree

8 files changed

+500
-904
lines changed

8 files changed

+500
-904
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
`p=tt.nnet.sigmoid(logit_p)`.
1010
- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method
1111
which in turn makes sampling from a `DensityDist` possible.
12-
- Add new 'scatterplot' function, for plotting scatter or hexbin matrices of sampled parameters.
12+
- Add new 'pairplot' function, for plotting scatter or hexbin matrices of sampled parameters.
1313
Optionally it can plot divergences.
1414

1515
### Fixes

docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb

Lines changed: 328 additions & 759 deletions
Large diffs are not rendered by default.

pymc3/plots/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
from .traceplot import traceplot
77
from .energyplot import energyplot
88
from .densityplot import densityplot
9-
from .scatterplot import scatterplot
9+
from .pairplot import pairplot

pymc3/plots/artists.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,5 @@ def get_trace_dict(tr, varnames):
167167
traces['_'.join([v, str(i)])] = vi
168168
else:
169169
traces[v] = vals
170-
return traces
170+
return traces
171+

pymc3/plots/pairplot.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import warnings
2+
import matplotlib.gridspec as gridspec
3+
4+
try:
5+
import matplotlib.pyplot as plt
6+
except ImportError: # mpl is optional
7+
pass
8+
from ..util import get_default_varnames, is_transformed_name, get_untransformed_name
9+
from .artists import get_trace_dict, scale_text
10+
11+
12+
def pairplot(trace, varnames=None, figsize=None, text_size=None,
13+
gs=None, ax=None, hexbin=False, plot_transformed=False,
14+
divergences=False, kwargs_divergence=None,
15+
sub_varnames=None, **kwargs):
16+
17+
"""
18+
Plot a scatter or hexbin matrix of the sampled parameters.
19+
20+
Parameters
21+
----------
22+
23+
trace : result of MCMC run
24+
varnames : list of variable names
25+
Variables to be plotted, if None all variable are plotted
26+
figsize : figure size tuple
27+
If None, size is (8 + numvars, 8 + numvars)
28+
text_size: int
29+
Text size for labels
30+
gs : Grid spec
31+
Matplotlib Grid spec.
32+
ax: axes
33+
Matplotlib axes
34+
hexbin : Boolean
35+
If True draws an hexbin plot
36+
plot_transformed : bool
37+
Flag for plotting automatically transformed variables in addition to
38+
original variables (defaults to False). Applies when varnames = None.
39+
When a list of varnames is passed, transformed variables can be passed
40+
using their names.
41+
divergences : Boolean
42+
If True divergences will be plotted in a diferent color
43+
kwargs_divergence : dicts, optional
44+
Aditional keywords passed to ax.scatter for divergences
45+
sub_varnames : list, optional
46+
Aditional varnames passed for plotting subsets of multidimensional
47+
variables
48+
Returns
49+
-------
50+
51+
ax : matplotlib axes
52+
53+
"""
54+
55+
if varnames is None:
56+
if plot_transformed:
57+
58+
remove = [get_untransformed_name(var) for var in trace.varnames
59+
if is_transformed_name(var)]
60+
varnames_copy = trace.varnames
61+
62+
try:
63+
[varnames_copy.remove(i) for i in remove]
64+
varnames = varnames_copy
65+
except:
66+
varnames = trace.varnames
67+
68+
trace_dict = get_trace_dict(trace,
69+
get_default_varnames(varnames, True))
70+
71+
else:
72+
trace_dict = get_trace_dict(trace,
73+
get_default_varnames(trace.varnames, False))
74+
75+
if sub_varnames is None:
76+
varnames = list(trace_dict.keys())
77+
78+
else:
79+
trace_dict = get_trace_dict(trace,
80+
get_default_varnames(trace.varnames, True))
81+
varnames = sub_varnames
82+
83+
else:
84+
trace_dict = get_trace_dict(trace, varnames)
85+
varnames = list(trace_dict.keys())
86+
87+
if text_size is None:
88+
text_size = scale_text(figsize, text_size=text_size)
89+
90+
if kwargs_divergence is None:
91+
kwargs_divergence = {}
92+
93+
numvars = len(varnames)
94+
95+
if figsize is None:
96+
figsize = (8+numvars, 8+numvars)
97+
98+
if numvars < 2:
99+
raise Exception('Number of variables to be plotted must be 2 or greater.')
100+
101+
if numvars == 2 and ax is not None:
102+
if hexbin:
103+
ax.hexbin(trace_dict[varnames[0]],
104+
trace_dict[varnames[1]], mincnt=1, **kwargs)
105+
else:
106+
ax.scatter(trace_dict[varnames[0]],
107+
trace_dict[varnames[1]], **kwargs)
108+
109+
if divergences:
110+
try:
111+
divergent = trace['diverging']
112+
except KeyError:
113+
warnings.warn('No divergences were found.')
114+
115+
diverge = (divergent == 1)
116+
ax.scatter(trace_dict[varnames[0]][diverge],
117+
trace_dict[varnames[1]][diverge], **kwargs_divergence)
118+
119+
if gs is None and ax is None:
120+
121+
gs = gridspec.GridSpec(numvars -1, numvars-1)
122+
123+
for i in range(0, numvars-1):
124+
var1 = trace_dict[varnames[i]]
125+
126+
for j in range(i, numvars-1):
127+
var2 = trace_dict[varnames[j+1]]
128+
129+
ax = plt.subplot(gs[j, i])
130+
131+
if hexbin:
132+
ax.hexbin(var1, var2, mincnt=1, **kwargs)
133+
else:
134+
ax.scatter(var1, var2, **kwargs)
135+
136+
if divergences:
137+
try:
138+
divergent = trace['diverging']
139+
except KeyError:
140+
warnings.warn('No divergences were found.')
141+
return ax
142+
143+
diverge = (divergent == 1)
144+
ax.scatter(trace_dict[varnames[0]][diverge],
145+
trace_dict[varnames[1]][diverge],
146+
**kwargs_divergence)
147+
148+
if j+1 != numvars-1:
149+
ax.set_xticks([])
150+
else:
151+
ax.set_xlabel('{}'.format(varnames[i]),
152+
fontsize=text_size)
153+
if i != 0:
154+
ax.set_yticks([])
155+
else:
156+
ax.set_ylabel('{}'.format(varnames[j+1]), fontsize=text_size)
157+
158+
ax.tick_params(labelsize=text_size)
159+
ax.tick_params(labelsize=text_size)
160+
161+
plt.tight_layout()
162+
return ax

pymc3/plots/scatterplot.py

Lines changed: 0 additions & 136 deletions
This file was deleted.

pymc3/plots/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
pass
55
import numpy as np
66
# plotting utilities can all be in this namespace
7-
from ..util import get_default_varnames # pylint: disable=unused-import
7+
from ..util import get_default_varnames # pylint: disable=unused-import
88

99

1010
def identity_transform(x):

pymc3/tests/test_plots.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .checks import close_to
77

88
from .models import multidimensional_model, simple_categorical
9-
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot, scatterplot
9+
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot, pairplot
1010
from ..plots.utils import make_2d
1111
from ..step_methods import Slice, Metropolis
1212
from ..sampling import sample
@@ -120,14 +120,14 @@ def test_plots_transformed():
120120
assert plot_posterior(trace).numCols == 1
121121
assert plot_posterior(trace, plot_transformed=True).shape == (2, )
122122

123-
def test_scatterplot():
123+
def test_pairplot():
124124
with pm.Model() as model:
125125
a = pm.Normal('a', shape=2)
126126
c = pm.HalfNormal('c', shape=2)
127127
b = pm.Normal('b', a, c, shape=2)
128128
d = pm.Normal('d', 100, 1)
129129
trace = pm.sample(1000)
130130

131-
scatterplot(trace)
132-
scatterplot(trace, hexbin=True, plot_transformed=True)
133-
scatterplot(trace, sub_varnames=['a_0', 'c_0', 'b_1'])
131+
pairplot(trace)
132+
pairplot(trace, hexbin=True, plot_transformed=True)
133+
pairplot(trace, sub_varnames=['a_0', 'c_0', 'b_1'])

0 commit comments

Comments
 (0)