Skip to content

Commit 0d34aed

Browse files
committed
attempt at implementing LatentSparse DTC and FITC approximations
could probably use some matrix computation optimization
1 parent e13c4cd commit 0d34aed

File tree

2 files changed

+89
-30
lines changed

2 files changed

+89
-30
lines changed

pymc3/gp/gp.py

Lines changed: 75 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from pymc3.gp.cov import Covariance, Constant
88
from pymc3.gp.mean import Zero
99
from pymc3.gp.util import (conditioned_vars,
10-
infer_shape, stabilize, cholesky, solve_lower, solve_upper)
10+
infer_shape, stabilize, cholesky, solve_lower, solve_upper,
11+
invert_dot, project_inverse)
1112
from pymc3.distributions import draw_values
1213
from pymc3.distributions.dist_math import eigh
1314
from ..math import cartesian, kron_dot, kron_diag
@@ -210,19 +211,27 @@ def conditional(self, name, Xnew, given=None, **kwargs):
210211
@conditioned_vars(["X", "Xu", "f"])
211212
class LatentSparse(Latent):
212213
R"""
213-
Approximate latent Gaussian process.
214+
Approximate latent Gaussian process (GP).
214215
215216
The `gp.LatentSparse` class is a direct implementation of a GP. No additive
216217
noise is assumed. It is called "Latent" because the underlying function
217218
values are treated as latent variables. It has a `prior` method and a
218219
`conditional` method. Given a mean and covariance function the
219-
function :math:`f(x)` is modeled as,
220+
function :math:`f(X)` is modeled as,
221+
222+
.. math::
223+
224+
f \sim \int p(f \mid u) p(u) \mathrm{d}u
225+
226+
where `u` is the GP function prior on a set of inducing points `Xu`
220227
221228
.. math::
222229
223-
f(x) \sim \mathcal{GP}\left(\mu(x), k(x_u, x_u')\right)
230+
u \mid X_u \sim \text{MvNormal}\left( \mu(X_u), K(X_u, X_u) \right)
224231
225-
with the inducing points x_u approximating the full covariance.
232+
The DTC and FITC approximations use a simplified `p(f | u) ~ q(f | u)`
233+
and the resulting `f` is a GP prior only
234+
in the case of the FITC approximation.
226235
227236
Use the `prior` and `conditional` methods to actually construct random
228237
variables representing the unknown, or latent, function whose
@@ -235,8 +244,11 @@ class LatentSparse(Latent):
235244
----------
236245
cov_func : None, 2D array, or instance of Covariance
237246
The covariance function. Defaults to zero.
247+
Must implement the `diag` method for the FITC approximation
238248
mean_func : None, instance of Mean
239249
The mean function. Defaults to zero.
250+
approx : str
251+
Approximation to use. One of ['DTC', 'FITC']
240252
241253
Examples
242254
--------
@@ -268,6 +280,7 @@ class LatentSparse(Latent):
268280
fcond = gp.conditional("fcond", Xnew=Xnew)
269281
"""
270282

283+
_available_approx = ("DTC", "FITC")
271284

272285
def __init__(self, mean_func=Zero(), cov_func=Constant(0.0), approx="FITC"):
273286
if approx not in self._available_approx:
@@ -276,27 +289,58 @@ def __init__(self, mean_func=Zero(), cov_func=Constant(0.0), approx="FITC"):
276289
super(Latent, self).__init__(mean_func, cov_func)
277290

278291
def _build_prior(self, name, X, Xu, **kwargs):
    R"""
    Build the approximate GP prior `f` over `X` from inducing points `Xu`.

    The inducing values `u` are drawn with the mean-plus-Cholesky
    parameterisation, ``u = mean_func(Xu) + L v`` with ``v ~ Normal(0, 1)``,
    and projected onto `X` through ``K(X, Xu) K(Xu, Xu)^{-1}``.

    Parameters
    ----------
    name : string
        Prefix for the names of the random variables created for the
        inducing points.
    X : array-like
        Input locations of the function values `f`.
    Xu : array-like
        Inducing input locations.
    **kwargs
        ``shape`` is popped and handed to `infer_shape`; the remaining
        keyword arguments are forwarded to the `pm.Normal` that holds
        the rotated inducing values.

    Returns
    -------
    f : random variable
        A `pm.Deterministic` for ``approx == 'DTC'``, a `pm.Normal` with
        the FITC diagonal variance for ``approx == 'FITC'``.
    """
    mu = self.mean_func(X)  # (n,)
    mu_u = self.mean_func(Xu)  # (m,)
    L = cholesky(stabilize(self.cov_func(Xu)))  # (m, m), K_uu = L L^T
    shape = infer_shape(Xu, kwargs.pop("shape", None))
    v = pm.Normal(name + "_u_rotated_", mu=0.0, sd=1.0, shape=shape, **kwargs)
    # mean + Cholesky construction of a multivariate normal draw
    u = pm.Deterministic(name + '_u', mu_u + tt.dot(L, v))  # (m,)
    # The Gaussian conditional mean is mu(X) + K_fu K_uu^{-1} (u - mu(Xu));
    # u has to be centred before projecting, otherwise a non-zero mean
    # function is counted twice.
    Kuuiu = invert_dot(L, u - mu_u)  # (m,)
    Kfu = self.cov_func(X, Xu)  # (n, m)
    f_mean = mu + tt.dot(Kfu, Kuuiu)  # (n, m) @ (m,) = (n,)
    # NOTE(review): the variable name "f" is hard-coded rather than derived
    # from `name`; kept as in the original -- confirm this is intended.
    if self.approx == 'DTC':
        # DTC: the conditional covariance K_ff - Q_ff is dropped entirely.
        f = pm.Deterministic("f", f_mean)
    elif self.approx == 'FITC':
        # FITC: keep only diag(K_ff - Q_ff), so f is a Normal with
        # independent components and sd = sqrt(variance).
        Qff_diag = project_inverse(Kfu, L, diag=True)
        Kff_diag = self.cov_func.diag(X)
        # clip guards against slightly negative variances from round-off
        var = tt.clip(Kff_diag - Qff_diag, 0, np.inf)
        f = pm.Normal("f", mu=f_mean, sd=tt.sqrt(var))
    return f
288309

289310
def prior(self, name, X, Xu, **kwargs):
290311
R"""
291312
Returns the GP prior distribution evaluated over the input
292-
locations `X`.
313+
locations `X` with inducing locations `Xu`.
293314
294315
This is the prior probability over the space
295316
of functions described by its mean and covariance function.
296317
318+
The DTC and FITC approximations use a simplified form of the true conditional `p(f | u)`
319+
297320
.. math::
321+
f \mid X, X_u \sim \text{MvNormal}\left( \mu(X) + K(X, X_u) K(X_u, X_u)^{-1} (u - \mu(X_u)), K(X, X) - Q(X, X) \right)
298322
299-
f \mid X, X_u \sim \text{MvNormal}\left( \mu(X), k(X_u, X_u') \right)
323+
where
324+
325+
.. math::
326+
u \mid X_u \sim \text{MvNormal}\left( \mu(X_u), K(X_u, X_u) \right)
327+
328+
and
329+
330+
.. math::
331+
332+
Q(X, X') = K(X, X_u) K(X_u, X_u)^{-1} K(X_u, X')
333+
334+
The DTC approximation uses (resulting in a non-GP prior)
335+
336+
.. math::
337+
K(X, X) - Q(X, X) \approx 0
338+
339+
The FITC approximation uses (resulting in a GP prior)
340+
341+
.. math::
342+
343+
K(X, X) - Q(X, X) \approx \mathrm{diag}(K(X, X) - Q(X, X))
300344
301345
Parameters
302346
----------
@@ -333,30 +377,31 @@ def _get_given_vals(self, given):
333377
return X, Xu, f, cov_total, mean_total
334378

335379
def _build_conditional(self, Xnew, X, Xu, f, cov_total, mean_total):
    R"""
    Build the mean and covariance of the approximate conditional at `Xnew`.

    Parameters
    ----------
    Xnew : array-like
        New input locations.
    X : array-like
        Input locations of the given function values `f`.
    Xu : array-like
        Inducing input locations.
    f : array-like
        Function values at `X` that are conditioned on.
    cov_total : Covariance
        Covariance function used for the `Xu` / `X` kernel matrices.
    mean_total : Mean
        Mean function subtracted from `f` before projecting.

    Returns
    -------
    mus, cov
        Mean vector and covariance matrix over `Xnew`.
    """
    Luu = cholesky(stabilize(cov_total(Xu)))  # K_uu = Luu Luu^T

    Kuf = cov_total(Xu, X)
    # Cholesky factor of K_uf K_fu, used to recover K_uu^{-1} u from f
    Luffu = cholesky(stabilize(tt.dot(Kuf, Kuf.T)))
    Ksu = self.cov_func(Xnew, Xu)
    r = f - mean_total(X)  # the equations are derived for 0-mean f
    Kuuiu = invert_dot(Luffu, tt.dot(Kuf, r))
    mus = self.mean_func(Xnew) + tt.dot(Ksu, Kuuiu)

    # Full matrices are needed here, so diag=False is passed explicitly
    # (project_inverse returns only the diagonal by default).
    Qss = project_inverse(Ksu, Luu, diag=False)
    cov = self.cov_func(Xnew) - Qss

    if self.approx == 'FITC':
        # diag(K_ff - Q_ff); cov_total is used so that K_ff comes from the
        # same covariance as Kuf / Luu that define Q_ff.
        # NOTE(review): the original referenced a bare, undefined
        # `cov_func` here -- confirm cov_total (vs self.cov_func) is the
        # intended covariance.
        Lambda = cov_total.diag(X) - project_inverse(Kuf.T, Luu, diag=True)
        Qsf = project_inverse(Ksu, Luu, diag=False, P_T=Kuf)
        # use the centred residual, consistent with the zero-mean derivation
        mus += tt.dot(Qsf, r / Lambda)
        cov -= tt.dot(Qsf, tt.transpose(Qsf / Lambda))
    return mus, cov
345400

346401
def conditional(self, name, Xnew, given=None, **kwargs):
347402
R"""
348-
Returns the conditional distribution evaluated over new input
349-
locations `Xnew`.
350-
351-
Given a set of function values `f` that
352-
the GP prior was over, the conditional distribution over a
353-
set of new points, `f_*` is
354-
355-
.. math::
356-
357-
f_* \mid f, X, X_* \sim \mathcal{GP}\left(
358-
K(X_*, X) K(X, X)^{-1} f \,,
359-
K(X_*, X_*) - K(X_*, X) K(X, X)^{-1} K(X, X_*) \right)
403+
Returns the approximate conditional distribution evaluated
404+
over new input locations `Xnew`.
360405
361406
Parameters
362407
----------

pymc3/gp/util.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@
99
solve_upper = tt.slinalg.Solve(A_structure='upper_triangular')
1010
solve = tt.slinalg.Solve(A_structure='general')
1111

12+
def invert_dot(L, X):
    """Return ``K^{-1} @ X`` for ``K = L @ L^T`` via two triangular solves.

    `L` is handed to `solve_lower` and `L.T` to `solve_upper`, so no
    explicit inverse of ``K`` is formed.
    """
    forward = solve_lower(L, X)  # L^{-1} X
    return solve_upper(L.T, forward)  # L^{-T} (L^{-1} X)
15+
16+
def project_inverse(P, L, diag=True, P_T=None):
    """Compute ``P @ K^{-1} @ P_T``, or only its diagonal, for ``K = L @ L^T``.

    Parameters
    ----------
    P : matrix
        Left factor of the product.
    L : matrix
        Factor of ``K``; it is passed to `solve_lower`.
    diag : bool
        If True (the default) only the diagonal of the product is returned.
        Pass ``diag=False`` to obtain the full matrix.
    P_T : matrix, optional
        Right factor of the product. Defaults to ``P.T``.

    Returns
    -------
    The diagonal of ``P @ K^{-1} @ P_T`` when `diag` is True, otherwise
    the full product.
    """
    symmetric = P_T is None
    if symmetric:
        P_T = P.T
    if not diag:
        return tt.dot(P, invert_dot(L, P_T))
    # P K^{-1} P_T = (L^{-1} P^T)^T (L^{-1} P_T), so its diagonal is the
    # column-wise sum of the elementwise product of the two solves.
    right = solve_lower(L, P_T)
    # With an explicit P_T the left factor differs from the right one and
    # must be solved for separately; otherwise the single solve is reused.
    left = right if symmetric else solve_lower(L, P.T)
    return tt.sum(left * right, axis=0)
25+
1226

1327
def infer_shape(X, n_points=None):
1428
if n_points is None:

0 commit comments

Comments
 (0)