Skip to content

Add scatterplot function #2861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 26, 2018
Merged

Conversation

agustinaarroyuelo
Copy link
Contributor

@agustinaarroyuelo agustinaarroyuelo commented Feb 15, 2018

This function allows to plot scatter matrices of the sampled parameters, or a subset of them. Additionally, it can display divergences. I updated the "Diagnosing biased Inference with Divergences" notebook with examples using this new feature. I am looking forward to receive feedback.

@@ -6,3 +6,4 @@
from .traceplot import traceplot
from .energyplot import energyplot
from .densityplot import densityplot
from .scatterplot import scatterplot
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add carriage return at the end of the line

traces['_'.join([v, str(i)])] = vi
else:
traces[v] = vals
return traces
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add carriage return at end of line

@junpenglao
Copy link
Member

This looks quite nice.
I think scatterplot might be a bit confusing with native matplotlib function. Maybe a more meaningful name such as pairvarplot? As it plot pairs of variables?

ax[j, 0].tick_params(labelsize=text_size)

plt.tight_layout()
return ax
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add carriage return

@fonnesbeck
Copy link
Member

This looks really good.


if varnames is None:
if plot_transformed:
trace_dict = get_trace_dict(trace, get_default_varnames(trace.varnames, True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is better to plot only the free_RVs if plot_transformed=True, otherwise you will have a plot showing only the transformation (e.g., tau and tau_log__, which would essentially be redundant).

Copy link
Member

@ColCarroll ColCarroll left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code looks really nice! Do you have an example you can post in the comments?

I made a few suggestions for simplifying the code, but matplotlib can be inscrutable, so there's a good chance they are bad suggestions.

if np.any(divergent):
ax.scatter(trace_dict[varnames[0]][divergent == 1], trace_dict[varnames[1]][divergent == 1], **kwargs_divergence)
else:
print('No divergences were found.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this whole comment also applies to the similar block below)

could you remove this print statement?

I think this block might be nicer as just

try:
    divergent = trace['diverging']
except KeyError:
    warnings.warn(...)
    return ax
diverge = (divergent == 1)
ax.scatter(trace_dict[varnames[0]][diverge], trace_dict[varnames[1]][diverge], **kwargs_divergence)

The scatter should then just be empty if there are no divergences, which I think is fine (the other alternative would be something like putting text on the plot saying "there are no divergences found").

warnings.warn('There is no divergence information in the passed trace.')
return ax

if ax is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

creating default axes should be above, right after defining numvars. Then you also would not have to special case numvars==2.

return ax

if ax is None:
_, ax = plt.subplots(nrows=numvars, ncols=numvars, figsize=figsize)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you do nrows=numvars - 1, ncols=numvars - 1, then the two loops below would be

for i in range(numvars):
    for j in range(i + 1, numvars):
        ...

and you can remove the if i == j block

(this might be totally off base, if axes.remove() does not do what I think it does).

@junpenglao
Copy link
Member

Looking at the example it seems you need to set plot_transformed=True when plotting transformed RVs even when you specify the name in sub_varnames.

pm.scatterplot(short_trace,
               sub_varnames=['theta_0', 'tau_log__'], 
               divergences=True, 
               plot_transformed=True,
               color='red', figsize=(15, 10), kwargs_divergence={'color':'green'})

Is there a way to optimized it? Since plot_transformed=True becomes a bit redundent when you already supply the varname.

Also somewhat related is that, while I think using names like theta_0 is nice for non-scaler RVs, it is not always intuitive for users - what users interact and specifying when varname is involved, they usually don't need to think about which index to use - as pymc3 just handle it internally. In fact, currently if you pass sub_varnames=['theta', 'tau'] as argument the function doesn't work. I wonder if there is a better way to handle this?

@aloctavodia
Copy link
Member

I think scatterplot might be a bit confusing with native matplotlib function. Maybe a more meaningful name such as pairvarplot? As it plot pairs of variables?

I agree, pairvarplot is ok. Maybe something shorter would be better, like pairplot

@fonnesbeck
Copy link
Member

I'm not keen on pairvarplot. Its not a tremendously clear name. I don't think the name collision is a big deal -- that's what namespaces are for.

@agustinaarroyuelo
Copy link
Contributor Author

Thanks everyone for your insightful comments. I am taking every suggestion into account for my next commits.

I think pairplot is a better suited name than scatterplot, because this function includes hexbin plot, which is not strictly a scatter plot.

@ColCarroll Here are some examples:
pairplot_example3
pairplot_example2

Copy link
Member

@aseyboldt aseyboldt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this :-)


if divergences:
try:
divergent = trace['diverging']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be cleaner to use trace.get_sampler_stats, just in case there is a var named 'diverging'.

if np.any(divergent):
ax[j, i].scatter(var1[divergent == 1], var2[divergent == 1], **kwargs_divergence)
else:
print('No divergences were found.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no print. Do we even need to do anything here?

Copy link
Member

@aseyboldt aseyboldt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somehow github lost my comments the first time.
In general, you can also try to limit the line length to 80

If True draws an hexbin plot
plot_transformed : bool
Flag for plotting automatically transformed variables in addition to
original variables (defaults to False). Applies when varnames/sub_varnames = None.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line lengths

try:
divergent = trace['diverging']
if np.any(divergent):
ax[j, i].scatter(var1[divergent == 1], var2[divergent == 1], **kwargs_divergence)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use the actual position of the divergence here. They are buried in the warnings, see
https://github.com/pymc-devs/pymc3/blob/master/docs/source/notebooks/Diagnosing_biased_Inference_with_Divergences.ipynb
I don't want to expose this to users, but using it here seems fine.

@agustinaarroyuelo
Copy link
Contributor Author

Thanks you for your suggestions. I applied them in every case that was possible.

@twiecki
Copy link
Member

twiecki commented Feb 22, 2018

Looks like you need to rebase.

@@ -0,0 +1,162 @@
import warnings
import matplotlib.gridspec as gridspec
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line should be inside the try-except block below. Additionally you can change it to from matplotlib import gridspec

Text size for labels
gs : Grid spec
Matplotlib Grid spec.
ax: axes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing ax and gs could be confusing, could we just use gs? I see the ax is there for the special case of plotting a two variable pairplot, but the same effect can be achieved using just gs, right?

Returns
-------

ax : matplotlib axes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should return a gridspec object

ax : matplotlib axes

"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove blank line, try running autopep8 to fix all this style issues.

@agustinaarroyuelo agustinaarroyuelo changed the title Add new scatterplot function Add scatterplot function Feb 24, 2018
@aloctavodia
Copy link
Member

LGTM!

@ColCarroll ColCarroll merged commit 553f057 into pymc-devs:master Feb 26, 2018
@ColCarroll
Copy link
Member

Thanks, @agustinaarroyuelo!

@agustinaarroyuelo agustinaarroyuelo deleted the scatterplot branch February 26, 2018 15:14
@fonnesbeck
Copy link
Member

Thanks for the contribution, Agustina!

@agustinaarroyuelo
Copy link
Contributor Author

resolves #2745

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants