Skip to content

Commit 22d24af

Browse files
committed
GP: use theano cholesky op, stabilize only once, delegate cholesky to Mv
1 parent 81773f3 commit 22d24af

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

pymc3/gp/gp.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from pymc3.gp.cov import Covariance, Constant
99
from pymc3.gp.mean import Zero
1010
from pymc3.gp.util import (conditioned_vars,
11-
infer_shape, stabilize, cholesky, solve_lower, solve_upper)
11+
infer_shape, stabilize, solve_lower, solve_upper)
1212
from pymc3.distributions import draw_values
1313
from pymc3.distributions.dist_math import eigh
1414
from ..math import cartesian, kron_dot, kron_diag
1515

1616
__all__ = ['Latent', 'Marginal', 'TP', 'MarginalSparse', 'MarginalKron']
1717

18+
cholesky = tt.slinalg.cholesky
1819

1920
class Base(object):
2021
R"""
@@ -107,13 +108,13 @@ def __init__(self, mean_func=Zero(), cov_func=Constant(0.0)):
107108

108109
def _build_prior(self, name, X, reparameterize=True, **kwargs):
109110
mu = self.mean_func(X)
110-
chol = cholesky(stabilize(self.cov_func(X)))
111+
cov = stabilize(self.cov_func(X))
111112
shape = infer_shape(X, kwargs.pop("shape", None))
112113
if reparameterize:
113114
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=shape, **kwargs)
114-
f = pm.Deterministic(name, mu + tt.dot(chol, v))
115+
f = pm.Deterministic(name, mu + cholesky(cov).dot(v))
115116
else:
116-
f = pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
117+
f = pm.MvNormal(name, mu=mu, cov=cov, shape=shape, **kwargs)
117118
return f
118119

119120
def prior(self, name, X, reparameterize=True, **kwargs):
@@ -203,9 +204,8 @@ def conditional(self, name, Xnew, given=None, **kwargs):
203204
"""
204205
givens = self._get_given_vals(given)
205206
mu, cov = self._build_conditional(Xnew, *givens)
206-
chol = cholesky(stabilize(cov))
207207
shape = infer_shape(Xnew, kwargs.pop("shape", None))
208-
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
208+
return pm.MvNormal(name, mu=mu, cov=cov, shape=shape, **kwargs)
209209

210210

211211
@conditioned_vars(["X", "f", "nu"])
@@ -249,14 +249,14 @@ def __add__(self, other):
249249

250250
def _build_prior(self, name, X, reparameterize=True, **kwargs):
251251
mu = self.mean_func(X)
252-
chol = cholesky(stabilize(self.cov_func(X)))
252+
cov = stabilize(self.cov_func(X))
253253
shape = infer_shape(X, kwargs.pop("shape", None))
254254
if reparameterize:
255255
chi2 = pm.ChiSquared("chi2_", self.nu)
256256
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=shape, **kwargs)
257-
f = pm.Deterministic(name, (tt.sqrt(self.nu) / chi2) * (mu + tt.dot(chol, v)))
257+
f = pm.Deterministic(name, (tt.sqrt(self.nu) / chi2) * (mu + cholesky(cov).dot(v)))
258258
else:
259-
f = pm.MvStudentT(name, nu=self.nu, mu=mu, chol=chol, shape=shape, **kwargs)
259+
f = pm.MvStudentT(name, nu=self.nu, mu=mu, cov=cov, shape=shape, **kwargs)
260260
return f
261261

262262
def prior(self, name, X, reparameterize=True, **kwargs):
@@ -321,10 +321,9 @@ def conditional(self, name, Xnew, **kwargs):
321321

322322
X = self.X
323323
f = self.f
324-
nu2, mu, covT = self._build_conditional(Xnew, X, f)
325-
chol = cholesky(stabilize(covT))
324+
nu2, mu, cov = self._build_conditional(Xnew, X, f)
326325
shape = infer_shape(Xnew, kwargs.pop("shape", None))
327-
return pm.MvStudentT(name, nu=nu2, mu=mu, chol=chol, shape=shape, **kwargs)
326+
return pm.MvStudentT(name, nu=nu2, mu=mu, cov=cov, shape=shape, **kwargs)
328327

329328

330329
@conditioned_vars(["X", "y", "noise"])
@@ -418,15 +417,15 @@ def marginal_likelihood(self, name, X, y, noise, is_observed=True, **kwargs):
418417
if not isinstance(noise, Covariance):
419418
noise = pm.gp.cov.WhiteNoise(noise)
420419
mu, cov = self._build_marginal_likelihood(X, noise)
421-
chol = cholesky(stabilize(cov))
420+
cov = stabilize(cov)
422421
self.X = X
423422
self.y = y
424423
self.noise = noise
425424
if is_observed:
426-
return pm.MvNormal(name, mu=mu, chol=chol, observed=y, **kwargs)
425+
return pm.MvNormal(name, mu=mu, cov=cov, observed=y, **kwargs)
427426
else:
428427
shape = infer_shape(X, kwargs.pop("shape", None))
429-
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
428+
return pm.MvNormal(name, mu=mu, cov=cov, shape=shape, **kwargs)
430429

431430
def _get_given_vals(self, given):
432431
if given is None:
@@ -504,9 +503,8 @@ def conditional(self, name, Xnew, pred_noise=False, given=None, **kwargs):
504503

505504
givens = self._get_given_vals(given)
506505
mu, cov = self._build_conditional(Xnew, pred_noise, False, *givens)
507-
chol = cholesky(cov)
508506
shape = infer_shape(Xnew, kwargs.pop("shape", None))
509-
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
507+
return pm.MvNormal(name, mu=mu, cov=cov, shape=shape, **kwargs)
510508

511509
def predict(self, Xnew, point=None, diag=False, pred_noise=False, given=None):
512510
R"""
@@ -797,9 +795,8 @@ def conditional(self, name, Xnew, pred_noise=False, given=None, **kwargs):
797795

798796
givens = self._get_given_vals(given)
799797
mu, cov = self._build_conditional(Xnew, pred_noise, False, *givens)
800-
chol = cholesky(cov)
801798
shape = infer_shape(Xnew, kwargs.pop("shape", None))
802-
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
799+
return pm.MvNormal(name, mu=mu, cov=cov, shape=shape, **kwargs)
803800

804801

805802
@conditioned_vars(["Xs", "y", "sigma"])
@@ -959,7 +956,7 @@ def _build_conditional(self, Xnew, pred_noise, diag):
959956
cov = Km - Asq
960957
if pred_noise:
961958
cov += sigma * np.eye(cov.shape)
962-
return mu, cov
959+
return mu, stabilize(cov)
963960

964961
def conditional(self, name, Xnew, pred_noise=False, **kwargs):
965962
"""
@@ -996,9 +993,8 @@ def conditional(self, name, Xnew, pred_noise=False, **kwargs):
996993
constructor.
997994
"""
998995
mu, cov = self._build_conditional(Xnew, pred_noise, False)
999-
chol = cholesky(stabilize(cov))
1000996
shape = infer_shape(Xnew, kwargs.pop("shape", None))
1001-
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
997+
return pm.MvNormal(name, mu=mu, cov=cov, shape=shape, **kwargs)
1002998

1003999
def predict(self, Xnew, point=None, diag=False, pred_noise=False):
10041000
R"""

0 commit comments

Comments
 (0)