Skip to content

Commit 20057e4

Browse files
denadai2Junpeng Lao
authored and
Junpeng Lao
committed
Small refactoring of sample_gp (#2373)
* Small refactoring of sample_gp * removed check model is null from sampling
1 parent ee19480 commit 20057e4

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

pymc3/gp/gp.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
__all__ = ['GP', 'sample_gp']
1515

16+
1617
class GP(Continuous):
1718
"""Gausian process
1819
@@ -76,7 +77,8 @@ def logp(self, Y, X=None):
7677
return MvNormal.dist(mu, Sigma).logp(Y)
7778

7879

79-
def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, random_seed=None, progressbar=True, chol_const=True):
80+
def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, random_seed=None, progressbar=True,
81+
chol_const=True):
8082
"""Generate samples from a posterior Gaussian process.
8183
8284
Parameters
@@ -106,18 +108,17 @@ def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, ran
106108
-------
107109
Array of samples from posterior GP evaluated at Z.
108110
"""
109-
model = modelcontext(model)
110-
111111
if samples is None:
112112
samples = len(trace)
113113

114+
model = modelcontext(model)
115+
114116
if random_seed:
115117
np.random.seed(random_seed)
116118

119+
indices = np.random.randint(0, len(trace), samples)
117120
if progressbar:
118-
indices = tqdm(np.random.randint(0, len(trace), samples), total=samples)
119-
else:
120-
indices = np.random.randint(0, len(trace), samples)
121+
indices = tqdm(indices, total=samples)
121122

122123
K = gp.distribution.K
123124

@@ -134,11 +135,13 @@ def sample_gp(trace, gp, X_values, samples=None, obs_noise=True, model=None, ran
134135
else:
135136
S_inv = matrix_inverse(K(X))
136137

138+
S_xz_S_inv = tt.dot(S_xz.T, S_inv)
137139
# Posterior mean
138-
m_post = tt.dot(tt.dot(S_xz.T, S_inv), Y)
140+
m_post = tt.dot(S_xz_S_inv, Y)
139141
# Posterior covariance
140-
S_post = S_zz - tt.dot(tt.dot(S_xz.T, S_inv), S_xz)
142+
S_post = S_zz - tt.dot(S_xz_S_inv, S_xz)
141143

144+
correction = 0
142145
if chol_const:
143146
n = S_post.shape[0]
144147
correction = 1e-6 * tt.nlinalg.trace(S_post) * tt.eye(n)

pymc3/sampling.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,7 @@ def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
526526
if samples is None:
527527
samples = len(trace)
528528

529-
if model is None:
530-
model = modelcontext(model)
529+
model = modelcontext(model)
531530

532531
if vars is None:
533532
vars = model.observed_RVs

0 commit comments

Comments
 (0)