Skip to content

Add scatterplot function #2861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 26, 2018
Merged
4 changes: 2 additions & 2 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

### New features

<<<<<<< refs/remotes/pymc-devs/master
- Add `logit_p` keyword to `pm.Bernoulli`, so that users can specify the logit of the success probability. This is faster and more stable than using `p=tt.nnet.sigmoid(logit_p)`.
- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method which in turn makes sampling from a `DensityDist` possible.
- Effective sample size computation is updated. The estimation uses Geyer's initial positive sequence, which no longer truncates the autocorrelation series inaccurately. `pm.diagnostics.effective_n` now can reports N_eff>N.
- Added `KroneckerNormal` distribution and a corresponding `MarginalKron`
Gaussian Process implementation for efficient inference, along with
lower-level functions such as `cartesian` and `kronecker` products.
- Added `Coregion` covariance function.
- Add new 'pairplot' function, for plotting scatter or hexbin matrices of sampled parameters.
Optionally it can plot divergences.
- Plots of discrete distributions in the docstrings

### Fixes
Expand All @@ -31,7 +32,6 @@ deviation. This works better for multimodal distributions. Functions using KDE p

- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use
log_i0 to compute the logp.
>>>>>>> Changes in test for random method in DensityDist

### Deprecations

Expand Down
1,250 changes: 623 additions & 627 deletions docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymc3/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .traceplot import traceplot
from .energyplot import energyplot
from .densityplot import densityplot
from .pairplot import pairplot
25 changes: 25 additions & 0 deletions pymc3/plots/artists.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from scipy.stats import mode
from collections import OrderedDict

from pymc3.stats import hpd
from .kdeplot import fast_kde, kdeplot
Expand Down Expand Up @@ -144,3 +145,27 @@ def set_key_if_doesnt_exist(d, key, value):
display_ref_val(ref_val)
if rope is not None:
display_rope(rope)

def scale_text(figsize, text_size):
"""Scale text to figsize."""

if text_size is None and figsize is not None:
if figsize[0] <= 11:
return 12
else:
return figsize[0]
else:
return text_size

def get_trace_dict(tr, varnames):
traces = OrderedDict()
for v in varnames:
vals = tr.get_values(v, combine=True, squeeze=True)
if vals.ndim > 1:
vals_flat = vals.reshape(vals.shape[0], -1).T
for i, vi in enumerate(vals_flat):
traces['_'.join([v, str(i)])] = vi
else:
traces[v] = vals
return traces
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add carriage return at end of line


170 changes: 170 additions & 0 deletions pymc3/plots/pairplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import warnings

try:
import matplotlib.pyplot as plt
from matplotlib import gridspec
except ImportError: # mpl is optional
pass
from ..util import get_default_varnames, is_transformed_name, get_untransformed_name
from .artists import get_trace_dict, scale_text


def pairplot(trace, varnames=None, figsize=None, text_size=None,
gs=None, ax=None, hexbin=False, plot_transformed=False,
divergences=False, kwargs_divergence=None,
sub_varnames=None, **kwargs):
"""
Plot a scatter or hexbin matrix of the sampled parameters.

Parameters
----------

trace : result of MCMC run
varnames : list of variable names
Variables to be plotted, if None all variable are plotted
figsize : figure size tuple
If None, size is (8 + numvars, 8 + numvars)
text_size: int
Text size for labels
gs : Grid spec
Matplotlib Grid spec.
ax: axes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing ax and gs could be confusing, could we just use gs? I see the ax is there for the special case of plotting a two variable pairplot, but the same effect can be achieved using just gs, right?

Matplotlib axes
hexbin : Boolean
If True draws an hexbin plot
plot_transformed : bool
Flag for plotting automatically transformed variables in addition to
original variables (defaults to False). Applies when varnames = None.
When a list of varnames is passed, transformed variables can be passed
using their names.
divergences : Boolean
If True divergences will be plotted in a diferent color
kwargs_divergence : dicts, optional
Aditional keywords passed to ax.scatter for divergences
sub_varnames : list
Aditional varnames passed for plotting subsets of multidimensional
variables
Returns
-------

ax : matplotlib axes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should return a gridspec object

gs : matplotlib gridspec

"""
if varnames is None:
if plot_transformed:

varnames_copy = list(trace.varnames)
remove = [get_untransformed_name(var) for var in trace.varnames
if is_transformed_name(var)]

try:
[varnames_copy.remove(i) for i in remove]
varnames = varnames_copy
except ValueError:
varnames = varnames_copy

trace_dict = get_trace_dict(
trace, get_default_varnames(
varnames, plot_transformed))

else:
trace_dict = get_trace_dict(
trace, get_default_varnames(
trace.varnames, plot_transformed))

if sub_varnames is None:
varnames = list(trace_dict.keys())

else:
trace_dict = get_trace_dict(
trace, get_default_varnames(
trace.varnames, True))
varnames = sub_varnames

else:
trace_dict = get_trace_dict(trace, varnames)
varnames = list(trace_dict.keys())

if text_size is None:
text_size = scale_text(figsize, text_size=text_size)

if kwargs_divergence is None:
kwargs_divergence = {}

numvars = len(varnames)

if figsize is None:
figsize = (8 + numvars, 8 + numvars)

if numvars < 2:
raise Exception(
'Number of variables to be plotted must be 2 or greater.')

if numvars == 2 and ax is not None:
if hexbin:
ax.hexbin(trace_dict[varnames[0]],
trace_dict[varnames[1]], mincnt=1, **kwargs)
else:
ax.scatter(trace_dict[varnames[0]],
trace_dict[varnames[1]], **kwargs)

if divergences:
try:
divergent = trace['diverging']
except KeyError:
warnings.warn('No divergences were found.')

diverge = (divergent == 1)
ax.scatter(trace_dict[varnames[0]][diverge],
trace_dict[varnames[1]][diverge], **kwargs_divergence)
ax.set_xlabel('{}'.format(varnames[0]),
fontsize=text_size)
ax.set_ylabel('{}'.format(
varnames[1]), fontsize=text_size)
ax.tick_params(labelsize=text_size)

if gs is None and ax is None:
plt.figure(figsize=figsize)
gs = gridspec.GridSpec(numvars - 1, numvars - 1)

for i in range(0, numvars - 1):
var1 = trace_dict[varnames[i]]

for j in range(i, numvars - 1):
var2 = trace_dict[varnames[j + 1]]

ax = plt.subplot(gs[j, i])

if hexbin:
ax.hexbin(var1, var2, mincnt=1, **kwargs)
else:
ax.scatter(var1, var2, **kwargs)

if divergences:
try:
divergent = trace['diverging']
except KeyError:
warnings.warn('No divergences were found.')
return ax

diverge = (divergent == 1)
ax.scatter(var1[diverge],
var2[diverge],
**kwargs_divergence)

if j + 1 != numvars - 1:
ax.set_xticks([])
else:
ax.set_xlabel('{}'.format(varnames[i]),
fontsize=text_size)
if i != 0:
ax.set_yticks([])
else:
ax.set_ylabel('{}'.format(
varnames[j + 1]), fontsize=text_size)

ax.tick_params(labelsize=text_size)

plt.tight_layout()
return ax, gs
34 changes: 6 additions & 28 deletions pymc3/plots/posteriorplot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from collections import OrderedDict

try:
import matplotlib.pyplot as plt
except ImportError: # mpl is optional
pass
import numpy as np

from .artists import plot_posterior_op
from .artists import plot_posterior_op, get_trace_dict, scale_text

from .utils import identity_transform, get_default_varnames


Expand Down Expand Up @@ -63,17 +62,6 @@ def plot_posterior(trace, varnames=None, transform=identity_transform, figsize=N

"""

def scale_text(figsize, text_size=text_size):
"""Scale text to figsize."""

if text_size is None and figsize is not None:
if figsize[0] <= 11:
return 12
else:
return figsize[0]
else:
return text_size

def create_axes_grid(figsize, traces):
l_trace = len(traces)
if l_trace == 1:
Expand All @@ -89,27 +77,17 @@ def create_axes_grid(figsize, traces):
ax = ax[:-1]
return fig, ax

def get_trace_dict(tr, varnames):
traces = OrderedDict()
for v in varnames:
vals = tr.get_values(v, combine=True, squeeze=True)
if vals.ndim > 1:
vals_flat = vals.reshape(vals.shape[0], -1).T
for i, vi in enumerate(vals_flat):
traces['_'.join([v, str(i)])] = vi
else:
traces[v] = vals
return traces

if isinstance(trace, np.ndarray):
if figsize is None:
figsize = (6, 2)
if ax is None:
fig, ax = plt.subplots(figsize=figsize)


plot_posterior_op(transform(trace), ax=ax, bw=bw, kde_plot=kde_plot,
point_estimate=point_estimate, round_to=round_to, alpha_level=alpha_level,
ref_val=ref_val, rope=rope, text_size=scale_text(figsize), **kwargs)

else:
if varnames is None:
varnames = get_default_varnames(trace.varnames, plot_transformed)
Expand All @@ -135,8 +113,8 @@ def get_trace_dict(tr, varnames):
plot_posterior_op(tr_values, ax=a, bw=bw, kde_plot=kde_plot,
point_estimate=point_estimate, round_to=round_to,
alpha_level=alpha_level, ref_val=ref_val[idx],
rope=rope[idx], text_size=scale_text(figsize), **kwargs)
a.set_title(v, fontsize=scale_text(figsize))
rope=rope[idx], text_size=scale_text(figsize, text_size), **kwargs)
a.set_title(v, fontsize=scale_text(figsize, text_size))

plt.tight_layout()
return ax
Expand Down
2 changes: 1 addition & 1 deletion pymc3/plots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
pass
import numpy as np
# plotting utilities can all be in this namespace
from ..util import get_default_varnames # pylint: disable=unused-import
from ..util import get_default_varnames # pylint: disable=unused-import


def identity_transform(x):
Expand Down
16 changes: 15 additions & 1 deletion pymc3/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .checks import close_to

from .models import multidimensional_model, simple_categorical
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot
from ..plots import traceplot, forestplot, autocorrplot, plot_posterior, energyplot, densityplot, pairplot
from ..plots.utils import make_2d
from ..step_methods import Slice, Metropolis
from ..sampling import sample
Expand Down Expand Up @@ -66,6 +66,7 @@ def test_plots_multidimensional():
forestplot(trace)
densityplot(trace)


@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on GPU due to cores=2")
def test_multichain_plots():
model = build_disaster_model()
Expand Down Expand Up @@ -118,3 +119,16 @@ def test_plots_transformed():
assert autocorrplot(trace, plot_transformed=True).shape == (2, 2)
assert plot_posterior(trace).numCols == 1
assert plot_posterior(trace, plot_transformed=True).shape == (2, )

def test_pairplot():
with pm.Model() as model:
a = pm.Normal('a', shape=2)
c = pm.HalfNormal('c', shape=2)
b = pm.Normal('b', a, c, shape=2)
d = pm.Normal('d', 100, 1)
trace = pm.sample(1000)

pairplot(trace)
pairplot(trace, hexbin=True, plot_transformed=True)
pairplot(trace, sub_varnames=['a_0', 'c_0', 'b_1'])