-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Let plot_posterior_predictive_glm work with inferencedata too #4234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
cbab50c
2eeb5ae
22991c9
3eb3410
5c37ef7
58d0839
c145b0a
e28f2b8
08c11fc
987e170
f93da72
40ac14b
9540abb
a0fa02b
258018a
9147b34
ddcebeb
fa7687f
2ceaa13
9cf211a
63a5a85
0106d77
3fe8210
9e012fb
2a38f67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
pass | ||
import numpy as np | ||
|
||
from pymc3.backends.base import MultiTrace | ||
|
||
|
||
def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwargs): | ||
"""Plot posterior predictive of a linear model. | ||
|
@@ -47,10 +49,23 @@ def plot_posterior_predictive_glm(trace, eval=None, lm=None, samples=30, **kwarg | |
if "c" not in kwargs and "color" not in kwargs: | ||
kwargs["c"] = "k" | ||
|
||
plotting_fn = _plot_multitrace if isinstance(trace, MultiTrace) else _plot_inferencedata | ||
plotting_fn(trace, eval, lm, samples, kwargs) | ||
plt.title("Posterior predictive") | ||
|
||
|
||
def _plot_multitrace(trace, eval, lm, samples, kwargs): | ||
for rand_loc in np.random.randint(0, len(trace), samples): | ||
rand_sample = trace[rand_loc] | ||
plt.plot(eval, lm(eval, rand_sample), **kwargs) | ||
# Make sure to not plot label multiple times | ||
kwargs.pop("label", None) | ||
|
||
plt.title("Posterior predictive") | ||
|
||
def _plot_inferencedata(trace, eval, lm, samples, kwargs): | ||
trace_df = trace.posterior.to_dataframe() | ||
for rand_loc in np.random.randint(0, len(trace_df), samples): | ||
rand_sample = trace_df.iloc[rand_loc] | ||
plt.plot(eval, lm(eval, rand_sample), **kwargs) | ||
# Make sure to not plot label multiple times | ||
kwargs.pop("label", None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two functions have a lot of duplicated lines; I think they can be merged into one by checking There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we cast to > /home/mgorelli/pymc3-dev/pymc3/plots/posteriorplot.py(61)_plot_multitrace()
-> plt.plot(eval, lm(eval, rand_sample), **kwargs)
(Pdb) type(rand_sample)
<class 'dict'>
(Pdb) rand_sample
{'x': 1.0, 'Intercept': 1.0} at this point, the only lines they have in common are
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah right, I forgot the whole trace was given here, and not only There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure - I think this'd be slightly more expensive, but arguably it's worth it for the sake of much simpler code |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import matplotlib.pyplot as plt | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import numpy as np | ||
import pytest | ||
|
||
from arviz import from_pymc3 | ||
|
||
import pymc3 as pm | ||
|
||
from pymc3.backends.ndarray import point_list_to_multitrace | ||
from pymc3.plots import plot_posterior_predictive_glm | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@pytest.mark.parametrize("inferencedata", [True, False]) | ||
def test_plot_posterior_predictive_glm_defaults(inferencedata): | ||
with pm.Model() as model: | ||
pm.Normal("x") | ||
pm.Normal("Intercept") | ||
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model) | ||
if inferencedata: | ||
trace = from_pymc3(trace, model=model) | ||
_, ax = plt.subplots() | ||
plot_posterior_predictive_glm(trace, samples=1) | ||
lines = ax.get_lines() | ||
expected_xvalues = np.linspace(0, 1, 100) | ||
expected_yvalues = np.linspace(1, 2, 100) | ||
for line in lines: | ||
x_axis, y_axis = line.get_data() | ||
np.testing.assert_array_equal(x_axis, expected_xvalues) | ||
np.testing.assert_array_equal(y_axis, expected_yvalues) | ||
assert line.get_lw() == 0.2 | ||
assert line.get_c() == "k" | ||
|
||
|
||
@pytest.mark.parametrize("inferencedata", [True, False]) | ||
def test_plot_posterior_predictive_glm_non_defaults(inferencedata): | ||
with pm.Model() as model: | ||
pm.Normal("x") | ||
pm.Normal("Intercept") | ||
trace = point_list_to_multitrace([{"x": np.array([1]), "Intercept": np.array([1])}], model) | ||
if inferencedata: | ||
trace = from_pymc3(trace, model=model) | ||
_, ax = plt.subplots() | ||
plot_posterior_predictive_glm( | ||
trace, samples=1, lm=lambda x, _: x, eval=np.linspace(0, 1, 10), lw=0.3, c="b" | ||
) | ||
lines = ax.get_lines() | ||
expected_xvalues = np.linspace(0, 1, 10) | ||
expected_yvalues = np.linspace(0, 1, 10) | ||
for line in lines: | ||
x_axis, y_axis = line.get_data() | ||
np.testing.assert_array_equal(x_axis, expected_xvalues) | ||
np.testing.assert_array_equal(y_axis, expected_yvalues) | ||
assert line.get_lw() == 0.3 | ||
assert line.get_c() == "b" |
Uh oh!
There was an error while loading. Please reload this page.