Skip to content

Commit 553f057

Browse files
agustinaarroyueloColCarroll
authored and committed
Add scatterplot function (#2861)
* add new scatterplot function * followed suggestions made in #2861 * run Divergences notebook with examples * add carriage return * fix plot_transformed argument * fix fig_size * fix gridspec import error and minor issues * remove unused module import line
1 parent 4b3620c commit 553f057

File tree

8 files changed

+843
-659
lines changed

8 files changed

+843
-659
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
### New features
66

7-
<<<<<<< refs/remotes/pymc-devs/master
87
- Add `logit_p` keyword to `pm.Bernoulli`, so that users can specify the logit of the success probability. This is faster and more stable than using `p=tt.nnet.sigmoid(logit_p)`.
98
- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method which in turn makes sampling from a `DensityDist` possible.
109
- Effective sample size computation is updated. The estimation uses Geyer's initial positive sequence, which no longer truncates the autocorrelation series inaccurately. `pm.diagnostics.effective_n` can now report N_eff>N.
1110
- Added `KroneckerNormal` distribution and a corresponding `MarginalKron`
1211
Gaussian Process implementation for efficient inference, along with
1312
lower-level functions such as `cartesian` and `kronecker` products.
1413
- Added `Coregion` covariance function.
14+
- Add new 'pairplot' function, for plotting scatter or hexbin matrices of sampled parameters.
15+
Optionally it can plot divergences.
1516
- Plots of discrete distributions in the docstrings
1617

1718
### Fixes
@@ -31,7 +32,6 @@ deviation. This works better for multimodal distributions. Functions using KDE p
3132

3233
- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use
3334
log_i0 to compute the logp.
34-
>>>>>>> Changes in test for random method in DensityDist
3535

3636
### Deprecations
3737

docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb

Lines changed: 623 additions & 627 deletions
Large diffs are not rendered by default.

pymc3/plots/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .traceplot import traceplot
77
from .energyplot import energyplot
88
from .densityplot import densityplot
9+
from .pairplot import pairplot

pymc3/plots/artists.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy.stats import mode
3+
from collections import OrderedDict
34

45
from pymc3.stats import hpd
56
from .kdeplot import fast_kde, kdeplot
@@ -144,3 +145,27 @@ def set_key_if_doesnt_exist(d, key, value):
144145
display_ref_val(ref_val)
145146
if rope is not None:
146147
display_rope(rope)
148+
149+
def scale_text(figsize, text_size=None):
    """Scale text to figsize.

    Parameters
    ----------
    figsize : sequence or None
        Figure size; only its first element is inspected.
    text_size : int or None, optional
        Explicit text size. When given it is returned unchanged. Defaults to
        None so that callers may pass only ``figsize``.

    Returns
    -------
    The explicit ``text_size`` if one was supplied (or if ``figsize`` is
    None); otherwise 12 when ``figsize[0] <= 11`` and ``figsize[0]`` itself
    for larger figures.
    """
    if text_size is not None or figsize is None:
        # Nothing to derive: either the caller chose a size, or there is no
        # figure size to scale from (in which case this returns None).
        return text_size

    width = figsize[0]
    return 12 if width <= 11 else width
159+
160+
def get_trace_dict(tr, varnames):
    """Collect the sampled values of ``varnames`` into an ordered mapping.

    Values are fetched with ``tr.get_values(name, combine=True, squeeze=True)``.
    A result with more than one dimension is reshaped to
    ``(shape[0], -1)`` and split column-wise, each column being stored under
    ``'<name>_<index>'``; any other result is stored under ``name`` as is.

    Returns
    -------
    OrderedDict mapping (possibly suffixed) variable names to value arrays,
    in the order the names were supplied.
    """
    collected = OrderedDict()
    for name in varnames:
        samples = tr.get_values(name, combine=True, squeeze=True)
        if samples.ndim <= 1:
            collected[name] = samples
            continue
        # One row per flattened component after the transpose.
        components = samples.reshape(samples.shape[0], -1).T
        for index, component in enumerate(components):
            collected['_'.join([name, str(index)])] = component
    return collected
171+

pymc3/plots/pairplot.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import warnings
2+
3+
try:
4+
import matplotlib.pyplot as plt
5+
from matplotlib import gridspec
6+
except ImportError: # mpl is optional
7+
pass
8+
from ..util import get_default_varnames, is_transformed_name, get_untransformed_name
9+
from .artists import get_trace_dict, scale_text
10+
11+
12+
def pairplot(trace, varnames=None, figsize=None, text_size=None,
             gs=None, ax=None, hexbin=False, plot_transformed=False,
             divergences=False, kwargs_divergence=None,
             sub_varnames=None, **kwargs):
    """
    Plot a scatter or hexbin matrix of the sampled parameters.

    Parameters
    ----------

    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variables are plotted
    figsize : figure size tuple
        If None, size is (8 + numvars, 8 + numvars)
    text_size: int
        Text size for labels
    gs : Grid spec
        Matplotlib Grid spec. A new one is created only when both `gs` and
        `ax` are None.
    ax: axes
        Matplotlib axes. Only drawn on when exactly two variables are
        plotted.
    hexbin : Boolean
        If True draws an hexbin plot
    plot_transformed : bool
        Flag for plotting automatically transformed variables (defaults to
        False). Applies when varnames = None: every variable that has a
        transformed counterpart in the trace is dropped from the selection
        before the default variable filter is applied.
        When a list of varnames is passed, transformed variables can be passed
        using their names.
    divergences : Boolean
        If True divergences will be plotted in a different color. If the trace
        has no 'diverging' entry a warning is issued and the plot is drawn
        without them.
    kwargs_divergence : dicts, optional
        Additional keywords passed to ax.scatter for divergences
    sub_varnames : list
        Additional varnames passed for plotting subsets of multidimensional
        variables. Only used when varnames is None.
    **kwargs
        Passed to ax.hexbin or ax.scatter for the main plot.

    Returns
    -------

    ax : matplotlib axes
    gs : matplotlib gridspec

    Raises
    ------
    ValueError
        If fewer than two variables are selected for plotting.
    """
    trace_dict, varnames = _select_pairplot_traces(
        trace, varnames, plot_transformed, sub_varnames)

    if text_size is None:
        # NOTE(review): the default figsize is only filled in further down,
        # so with figsize=None scale_text has nothing to scale from and
        # text_size stays None here. Kept as is to preserve current output.
        text_size = scale_text(figsize, text_size=text_size)

    if kwargs_divergence is None:
        kwargs_divergence = {}

    numvars = len(varnames)

    if figsize is None:
        figsize = (8 + numvars, 8 + numvars)

    if numvars < 2:
        raise ValueError(
            'Number of variables to be plotted must be 2 or greater.')

    if numvars == 2 and ax is not None:
        diverge = _divergence_mask(trace) if divergences else None
        _draw_pair(ax, trace_dict[varnames[0]], trace_dict[varnames[1]],
                   hexbin, diverge, kwargs, kwargs_divergence)
        ax.set_xlabel('{}'.format(varnames[0]),
                      fontsize=text_size)
        ax.set_ylabel('{}'.format(
            varnames[1]), fontsize=text_size)
        ax.tick_params(labelsize=text_size)

    if gs is None and ax is None:
        plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(numvars - 1, numvars - 1)

        # Looked up once for the whole grid rather than once per subplot.
        diverge = _divergence_mask(trace) if divergences else None

        for i in range(0, numvars - 1):
            var1 = trace_dict[varnames[i]]

            for j in range(i, numvars - 1):
                var2 = trace_dict[varnames[j + 1]]

                ax = plt.subplot(gs[j, i])
                _draw_pair(ax, var1, var2, hexbin, diverge,
                           kwargs, kwargs_divergence)

                # Only the bottom row carries x labels/ticks and only the
                # first column carries y labels/ticks.
                if j + 1 != numvars - 1:
                    ax.set_xticks([])
                else:
                    ax.set_xlabel('{}'.format(varnames[i]),
                                  fontsize=text_size)
                if i != 0:
                    ax.set_yticks([])
                else:
                    ax.set_ylabel('{}'.format(
                        varnames[j + 1]), fontsize=text_size)

                ax.tick_params(labelsize=text_size)

    plt.tight_layout()
    return ax, gs


def _select_pairplot_traces(trace, varnames, plot_transformed, sub_varnames):
    """Return ``(trace_dict, varnames)`` for the variables pairplot should draw."""
    if varnames is not None:
        trace_dict = get_trace_dict(trace, varnames)
        return trace_dict, list(trace_dict.keys())

    if sub_varnames is not None:
        # The caller names the flattened components directly, so make every
        # variable (transformed ones included) available for lookup.
        trace_dict = get_trace_dict(
            trace, get_default_varnames(trace.varnames, True))
        return trace_dict, sub_varnames

    if plot_transformed:
        # Drop each variable that has a transformed counterpart in the trace.
        # Filtering (rather than list.remove) means one absent name cannot
        # prevent the remaining names from being dropped.
        replaced = {get_untransformed_name(var) for var in trace.varnames
                    if is_transformed_name(var)}
        candidates = [var for var in trace.varnames if var not in replaced]
    else:
        candidates = trace.varnames

    trace_dict = get_trace_dict(
        trace, get_default_varnames(candidates, plot_transformed))
    return trace_dict, list(trace_dict.keys())


def _divergence_mask(trace):
    """Return a mask of divergent samples, or None (with a warning) if unavailable."""
    try:
        divergent = trace['diverging']
    except KeyError:
        warnings.warn('No divergences were found.')
        return None
    return divergent == 1


def _draw_pair(ax, x, y, hexbin, diverge, kwargs, kwargs_divergence):
    """Draw one scatter/hexbin panel on ``ax`` and overlay divergent samples if given."""
    if hexbin:
        ax.hexbin(x, y, mincnt=1, **kwargs)
    else:
        ax.scatter(x, y, **kwargs)

    if diverge is not None:
        ax.scatter(x[diverge], y[diverge], **kwargs_divergence)

pymc3/plots/posteriorplot.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from collections import OrderedDict
2-
31
try:
42
import matplotlib.pyplot as plt
53
except ImportError: # mpl is optional
64
pass
75
import numpy as np
86

9-
from .artists import plot_posterior_op
7+
from .artists import plot_posterior_op, get_trace_dict, scale_text
8+
109
from .utils import identity_transform, get_default_varnames
1110

1211

@@ -63,17 +62,6 @@ def plot_posterior(trace, varnames=None, transform=identity_transform, figsize=N
6362
6463
"""
6564

66-
def scale_text(figsize, text_size=text_size):
67-
"""Scale text to figsize."""
68-
69-
if text_size is None and figsize is not None:
70-
if figsize[0] <= 11:
71-
return 12
72-
else:
73-
return figsize[0]
74-
else:
75-
return text_size
76-
7765
def create_axes_grid(figsize, traces):
7866
l_trace = len(traces)
7967
if l_trace == 1:
@@ -89,27 +77,17 @@ def create_axes_grid(figsize, traces):
8977
ax = ax[:-1]
9078
return fig, ax
9179

92-
def get_trace_dict(tr, varnames):
93-
traces = OrderedDict()
94-
for v in varnames:
95-
vals = tr.get_values(v, combine=True, squeeze=True)
96-
if vals.ndim > 1:
97-
vals_flat = vals.reshape(vals.shape[0], -1).T
98-
for i, vi in enumerate(vals_flat):
99-
traces['_'.join([v, str(i)])] = vi
100-
else:
101-
traces[v] = vals
102-
return traces
103-
10480
if isinstance(trace, np.ndarray):
10581
if figsize is None:
10682
figsize = (6, 2)
10783
if ax is None:
10884
fig, ax = plt.subplots(figsize=figsize)
10985

86+
11087
plot_posterior_op(transform(trace), ax=ax, bw=bw, kde_plot=kde_plot,
11188
point_estimate=point_estimate, round_to=round_to, alpha_level=alpha_level,
11289
ref_val=ref_val, rope=rope, text_size=scale_text(figsize), **kwargs)
90+
11391
else:
11492
if varnames is None:
11593
varnames = get_default_varnames(trace.varnames, plot_transformed)
@@ -135,8 +113,8 @@ def get_trace_dict(tr, varnames):
135113
plot_posterior_op(tr_values, ax=a, bw=bw, kde_plot=kde_plot,
136114
point_estimate=point_estimate, round_to=round_to,
137115
alpha_level=alpha_level, ref_val=ref_val[idx],
138-
rope=rope[idx], text_size=scale_text(figsize), **kwargs)
139-
a.set_title(v, fontsize=scale_text(figsize))
116+
rope=rope[idx], text_size=scale_text(figsize, text_size), **kwargs)
117+
a.set_title(v, fontsize=scale_text(figsize, text_size))
140118

141119
plt.tight_layout()
142120
return ax

pymc3/plots/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
pass
55
import numpy as np
66
# plotting utilities can all be in this namespace
7-
from ..util import get_default_varnames # pylint: disable=unused-import
7+
from ..util import get_default_varnames # pylint: disable=unused-import
88

99

1010
def identity_transform(x):

pymc3/tests/test_plots.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .checks import close_to
77

88
from .models import multidimensional_model, simple_categorical
9-
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot
9+
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot, pairplot
1010
from ..plots.utils import make_2d
1111
from ..step_methods import Slice, Metropolis
1212
from ..sampling import sample
@@ -66,6 +66,7 @@ def test_plots_multidimensional():
6666
forestplot(trace)
6767
densityplot(trace)
6868

69+
6970
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on GPU due to cores=2")
7071
def test_multichain_plots():
7172
model = build_disaster_model()
@@ -118,3 +119,16 @@ def test_plots_transformed():
118119
assert autocorrplot(trace, plot_transformed=True).shape == (2, 2)
119120
assert plot_posterior(trace).numCols == 1
120121
assert plot_posterior(trace, plot_transformed=True).shape == (2, )
122+
123+
def test_pairplot():
    """Smoke-test pairplot with default, hexbin/transformed and sub_varnames options."""
    with pm.Model():
        a = pm.Normal('a', shape=2)
        c = pm.HalfNormal('c', shape=2)
        pm.Normal('b', a, c, shape=2)
        pm.Normal('d', 100, 1)
        trace = pm.sample(1000)

        # Each entry is one set of keyword options exercised against the
        # same trace; the calls only need to complete without raising.
        option_sets = (
            {},
            {'hexbin': True, 'plot_transformed': True},
            {'sub_varnames': ['a_0', 'c_0', 'b_1']},
        )
        for options in option_sets:
            pairplot(trace, **options)
134+

0 commit comments

Comments
 (0)