Skip to content

Commit ad59074

Browse files
fix & specify type and shapes for plot_gp_dist (#3913)
* fix & specify type and shapes for plot_gp_dist * warn user about nan samples closes #3917 * test that UserWarning is triggered when some samples are nan
1 parent 18f1e51 commit ad59074

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

pymc3/gp/util.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from scipy.cluster.vq import kmeans
1616
import numpy as np
1717
import theano.tensor as tt
18+
import warnings
1819

1920
cholesky = tt.slinalg.cholesky
2021
solve_lower = tt.slinalg.Solve(A_structure='lower_triangular')
@@ -83,17 +84,19 @@ def setter(self, val):
8384
return gp_wrapper
8485

8586

86-
def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None):
87+
def plot_gp_dist(ax, samples:np.ndarray, x:np.ndarray, plot_samples=True, palette="Reds", fill_alpha=0.8, samples_alpha=0.1, fill_kwargs=None, samples_kwargs=None):
8788
""" A helper function for plotting 1D GP posteriors from trace
8889
8990
Parameters
9091
----------
9192
ax: axes
9293
Matplotlib axes.
93-
samples: trace or list of traces
94-
Trace(s) or posterior predictive sample from a GP.
95-
x: array
94+
samples: numpy.ndarray
95+
Array of S posterior predictive sample from a GP.
96+
Expected shape: (S, X)
97+
x: numpy.ndarray
9698
Grid of X values corresponding to the samples.
99+
Expected shape: (X,) or (X, 1), or (1, X)
97100
plot_samples: bool
98101
Plot the GP samples along with posterior (defaults True).
99102
palette: str
@@ -118,6 +121,12 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", fill_alpha=0
118121
fill_kwargs = {}
119122
if samples_kwargs is None:
120123
samples_kwargs = {}
124+
if np.any(np.isnan(samples)):
125+
warnings.warn(
126+
'There are `nan` entries in the [samples] arguments. '
127+
'The plot will not contain a band!',
128+
UserWarning
129+
)
121130

122131
cmap = plt.get_cmap(palette)
123132
percs = np.linspace(51, 99, 40)

pymc3/tests/test_gp.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def test_inv_rightprod(self):
247247
with pytest.raises(ValueError, match=r"cannot combine"):
248248
cov = M + pm.gp.cov.ExpQuad(1, 1.)
249249

250+
250251
class TestCovExponentiation:
251252
def test_symexp_cov(self):
252253
X = np.linspace(0, 1, 10)[:, None]
@@ -539,6 +540,7 @@ def test_1d(self):
539540
Kd = theano.function([],cov(X, diag=True))()
540541
npt.assert_allclose(np.diag(K), Kd, atol=1e-5)
541542

543+
542544
class TestCosine:
543545
def test_1d(self):
544546
X = np.linspace(0, 1, 10)[:, None]
@@ -1142,3 +1144,36 @@ def testMarginalKronRaises(self):
11421144
cov_funcs=self.cov_funcs)
11431145
with pytest.raises(TypeError):
11441146
gp1 + gp2
1147+
1148+
1149+
class TestUtil:
1150+
def test_plot_gp_dist(self):
1151+
"""Test that the plotting helper works with the stated input shapes."""
1152+
import matplotlib.pyplot as plt
1153+
X = 100
1154+
S = 500
1155+
fig, ax = plt.subplots()
1156+
pm.gp.util.plot_gp_dist(
1157+
ax,
1158+
x=np.linspace(0, 50, X),
1159+
samples=np.random.normal(np.arange(X), size=(S, X))
1160+
)
1161+
plt.close()
1162+
pass
1163+
1164+
def test_plot_gp_dist_warn_nan(self):
1165+
"""Test that the plotting helper works with the stated input shapes."""
1166+
import matplotlib.pyplot as plt
1167+
X = 100
1168+
S = 500
1169+
samples = np.random.normal(np.arange(X), size=(S, X))
1170+
samples[15, 3] = np.nan
1171+
fig, ax = plt.subplots()
1172+
with pytest.warns(UserWarning):
1173+
pm.gp.util.plot_gp_dist(
1174+
ax,
1175+
x=np.linspace(0, 50, X),
1176+
samples=samples
1177+
)
1178+
plt.close()
1179+
pass

0 commit comments

Comments
 (0)