Commit 90f20a2
Add multi-output support to GP Latent (#7471)
* Port 7226 and add dims support
* Fix typo in HSGP prior method
* Fix HSGP test

Co-authored-by: hchen19 <[email protected]>
1 parent: d313012 · commit: 90f20a2

4 files changed, 141 insertions(+), 23 deletions(-)
pymc/gp/gp.py

Lines changed: 61 additions & 13 deletions
@@ -148,18 +148,37 @@ class Latent(Base):
     def __init__(self, *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
-    def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
+    def _build_prior(
+        self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs
+    ):
         mu = self.mean_func(X)
         cov = stabilize(self.cov_func(X), jitter)
         if reparameterize:
-            size = np.shape(X)[0]
-            v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
-            f = pm.Deterministic(name, mu + cholesky(cov).dot(v), dims=kwargs.get("dims", None))
+            if "dims" in kwargs:
+                v = pm.Normal(
+                    name + "_rotated_",
+                    mu=0.0,
+                    sigma=1.0,
+                    **kwargs,
+                )
+
+            else:
+                size = (n_outputs, X.shape[0]) if n_outputs > 1 else X.shape[0]
+                v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
+
+            f = pm.Deterministic(
+                name,
+                mu + cholesky(cov).dot(v.T).transpose(),
+                dims=kwargs.get("dims", None),
+            )
+
         else:
-            f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+            mu_stack = pt.stack([mu] * n_outputs, axis=0) if n_outputs > 1 else mu
+            f = pm.MvNormal(name, mu=mu_stack, cov=cov, **kwargs)
+
         return f
 
-    def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
+    def prior(self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
         R"""
         Returns the GP prior distribution evaluated over the input
         locations `X`.
@@ -178,6 +197,12 @@ def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
         X : array-like
             Function input values. If one-dimensional, must be a column
             vector with shape `(n, 1)`.
+        n_outputs : int, default 1
+            Number of output GPs. If you're using `dims`, make sure their size
+            is equal to `(n_outputs, X.shape[0])`, i.e. the number of output GPs
+            by the number of input points.
+            Example: `gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))`,
+            where `len(n_gps) = 3` and `len(x_dim) = X.shape[0]`.
         reparameterize : bool, default True
             Reparameterize the distribution by rotating the random
             variable by the Cholesky factor of the covariance matrix.
@@ -188,10 +213,12 @@ def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
             Extra keyword arguments that are passed to :class:`~pymc.MvNormal`
             distribution constructor.
         """
+        f = self._build_prior(name, X, n_outputs, reparameterize, jitter, **kwargs)
 
-        f = self._build_prior(name, X, reparameterize, jitter, **kwargs)
         self.X = X
         self.f = f
+        self.n_outputs = n_outputs
+
         return f
 
     def _get_given_vals(self, given):
@@ -212,12 +239,16 @@ def _get_given_vals(self, given):
     def _build_conditional(self, Xnew, X, f, cov_total, mean_total, jitter):
         Kxx = cov_total(X)
         Kxs = self.cov_func(X, Xnew)
+
         L = cholesky(stabilize(Kxx, jitter))
         A = solve_lower(L, Kxs)
-        v = solve_lower(L, f - mean_total(X))
-        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
+        v = solve_lower(L, (f - mean_total(X)).T)
+
+        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T
+
         Kss = self.cov_func(Xnew)
         cov = Kss - pt.dot(pt.transpose(A), A)
+
         return mu, cov
 
     def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
@@ -255,7 +286,9 @@ def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
         """
         givens = self._get_given_vals(given)
         mu, cov = self._build_conditional(Xnew, *givens, jitter)
-        return pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+        f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+
+        return f
 
 
 @conditioned_vars(["X", "f", "nu"])
@@ -447,7 +480,15 @@ def _build_marginal_likelihood(self, X, noise_func, jitter):
         return mu, stabilize(cov, jitter)
 
     def marginal_likelihood(
-        self, name, X, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
+        self,
+        name,
+        X,
+        y,
+        sigma=None,
+        noise=None,
+        jitter=JITTER_DEFAULT,
+        is_observed=True,
+        **kwargs,
     ):
         R"""
         Returns the marginal likelihood distribution, given the input
@@ -529,21 +570,28 @@ def _build_conditional(
         Kxs = self.cov_func(X, Xnew)
         Knx = noise_func(X)
         rxx = y - mean_total(X)
+
         L = cholesky(stabilize(Kxx, jitter) + Knx)
         A = solve_lower(L, Kxs)
-        v = solve_lower(L, rxx)
-        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
+        v = solve_lower(L, rxx.T)
+        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T
+
         if diag:
             Kss = self.cov_func(Xnew, diag=True)
             var = Kss - pt.sum(pt.square(A), 0)
+
             if pred_noise:
                 var += noise_func(Xnew, diag=True)
+
             return mu, var
+
         else:
             Kss = self.cov_func(Xnew)
             cov = Kss - pt.dot(pt.transpose(A), A)
+
             if pred_noise:
                 cov += noise_func(Xnew)
+
             return mu, cov if pred_noise else stabilize(cov, jitter)
 
     def conditional(

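Taken together, the gp.py changes let a single Latent GP carry several outputs that share one covariance matrix: the rotated variable gains a leading n_outputs dimension, and the conditional transposes the residuals so one Cholesky solve serves every output row. A minimal usage sketch of the new argument (the data, coords, and hyperparameters below are illustrative, not part of the commit):

    import numpy as np
    import pymc as pm

    X = np.linspace(0, 10, 50)[:, None]  # 50 input points, 1 feature

    with pm.Model(coords={"gp": range(3), "obs": range(50)}) as model:
        ell = pm.Gamma("ell", alpha=2.0, beta=1.0)
        cov = pm.gp.cov.ExpQuad(1, ls=ell)
        gp = pm.gp.Latent(cov_func=cov)
        # Three output GPs sharing one covariance; f has shape
        # (n_outputs, X.shape[0]) == (3, 50), matching the dims sizes
        # required by the new docstring.
        f = gp.prior("f", X=X, n_outputs=3, dims=("gp", "obs"))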
pymc/gp/hsgp_approx.py

Lines changed: 9 additions & 4 deletions
@@ -442,18 +442,23 @@ def prior(
             Dimension name for the GP random variable.
         """
         phi, sqrt_psd = self.prior_linearized(X)
+        self._sqrt_psd = sqrt_psd
 
         if self._parametrization == "noncentered":
             self._beta = pm.Normal(
-                f"{name}_hsgp_coeffs_",
-                size=self._m_star - int(self._drop_first),
+                f"{name}_hsgp_coeffs",
+                size=self.n_basis_vectors - int(self._drop_first),
                 dims=hsgp_coeffs_dims,
             )
-            self._sqrt_psd = sqrt_psd
             f = self.mean_func(X) + phi @ (self._beta * self._sqrt_psd)
 
         elif self._parametrization == "centered":
-            self._beta = pm.Normal(f"{name}_hsgp_coeffs_", sigma=sqrt_psd, dims=hsgp_coeffs_dims)
+            self._beta = pm.Normal(
+                f"{name}_hsgp_coeffs",
+                sigma=sqrt_psd,
+                size=self.n_basis_vectors - int(self._drop_first),
+                dims=hsgp_coeffs_dims,
+            )
             f = self.mean_func(X) + phi @ self._beta
 
         self.f = pm.Deterministic(name, f, dims=gp_dims)

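Two details of the hsgp_approx.py hunk are worth flagging: the coefficient variable is renamed from f"{name}_hsgp_coeffs_" to f"{name}_hsgp_coeffs" (no trailing underscore), and the centered parametrization now receives an explicit size, mirroring the noncentered branch. Code that looks the coefficients up by name must follow the rename; a short sketch, with hyperparameters assumed for illustration:

    import numpy as np
    import pymc as pm

    X = np.linspace(0, 10, 100)[:, None]

    with pm.Model() as model:
        cov = pm.gp.cov.Matern52(1, ls=1.0)
        gp = pm.gp.HSGP(m=[25], c=4.0, cov_func=cov)
        f = gp.prior("f", X=X)

    # Before this commit the coefficients were registered as
    # "f_hsgp_coeffs_"; after it the name has no trailing underscore.
    coeffs = model["f_hsgp_coeffs"]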
tests/gp/test_gp.py

Lines changed: 70 additions & 5 deletions
@@ -17,6 +17,7 @@
 
 import numpy as np
 import numpy.testing as npt
+import pytensor.tensor as pt
 import pytest
 
 import pymc as pm
@@ -90,7 +91,12 @@ def test_raise_value_error(self):
         with self.model:
             with pytest.raises(ValueError):
                 self.gp.marginal_likelihood(
-                    "like_both", X=self.x, Xu=self.xu, y=self.y, noise=self.sigma, sigma=self.sigma
+                    "like_both",
+                    X=self.x,
+                    Xu=self.xu,
+                    y=self.y,
+                    noise=self.sigma,
+                    sigma=self.sigma,
                 )
 
             with pytest.raises(ValueError):
@@ -177,7 +183,11 @@ def setup_method(self):
             pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]),
             pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3]),
         )
-        self.means = (pm.gp.mean.Constant(0.5), pm.gp.mean.Constant(0.5), pm.gp.mean.Constant(0.5))
+        self.means = (
+            pm.gp.mean.Constant(0.5),
+            pm.gp.mean.Constant(0.5),
+            pm.gp.mean.Constant(0.5),
+        )
 
     def testAdditiveMarginal(self):
         with pm.Model() as model1:
@@ -199,7 +209,9 @@ def testAdditiveMarginal(self):
 
         with model1:
             fp1 = gpsum.conditional(
-                "fp1", self.Xnew, given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum}
+                "fp1",
+                self.Xnew,
+                given={"X": self.X, "y": self.y, "sigma": self.noise, "gp": gpsum},
             )
         with model2:
             fp2 = gptot.conditional("fp2", self.Xnew)
@@ -230,7 +242,9 @@ def testAdditiveMarginalApprox(self, approx):
 
         with pm.Model() as model2:
             gptot = pm.gp.MarginalApprox(
-                mean_func=reduce(add, self.means), cov_func=reduce(add, self.covs), approx=approx
+                mean_func=reduce(add, self.means),
+                cov_func=reduce(add, self.covs),
+                approx=approx,
             )
             fsum = gptot.marginal_likelihood("f", self.X, Xu, self.y, sigma=sigma)
             model2_logp = model2.compile_logp()({})
@@ -352,6 +366,53 @@ def testLatent2(self):
         latent_logp = model.compile_logp()({"f_rotated_": y_rotated, "p": self.pnew})
         npt.assert_allclose(latent_logp, self.logp, atol=5)
 
+    def testLatentMultioutput(self):
+        n_outputs = 2
+        X = np.random.randn(20, 3)
+        y = np.random.randn(n_outputs, 20)
+        Xnew = np.random.randn(30, 3)
+        pnew = np.random.randn(n_outputs, 30)
+
+        with pm.Model() as latent_model:
+            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
+            mean_func = pm.gp.mean.Constant(0.5)
+            latent_gp = pm.gp.Latent(mean_func=mean_func, cov_func=cov_func)
+            latent_f = latent_gp.prior("f", X, n_outputs=n_outputs, reparameterize=True)
+            latent_p = latent_gp.conditional("p", Xnew)
+
+        with pm.Model() as marginal_model:
+            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
+            mean_func = pm.gp.mean.Constant(0.5)
+            marginal_gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
+            marginal_f = marginal_gp.marginal_likelihood("f", X, y, sigma=0.0)
+            marginal_p = marginal_gp.conditional("p", Xnew)
+
+        assert tuple(latent_f.shape.eval()) == tuple(marginal_f.shape.eval()) == y.shape
+        assert tuple(latent_p.shape.eval()) == tuple(marginal_p.shape.eval()) == pnew.shape
+
+        chol = np.linalg.cholesky(cov_func(X).eval())
+        v = np.linalg.solve(chol, (y - 0.5).T)
+        A = np.linalg.solve(chol, cov_func(X, Xnew).eval()).T
+        mu_cond = mean_func(Xnew).eval() + (A @ v).T
+        cov_cond = cov_func(Xnew, Xnew).eval() - A @ A.T
+
+        with pm.Model() as numpy_model:
+            numpy_p = pm.MvNormal.dist(mu=pt.as_tensor(mu_cond), cov=pt.as_tensor(cov_cond))
+
+        latent_rv_logp = pm.logp(latent_p, pnew)
+        marginal_rv_logp = pm.logp(marginal_p, pnew)
+        numpy_rv_logp = pm.logp(numpy_p, pnew)
+
+        assert (
+            latent_rv_logp.shape.eval()
+            == marginal_rv_logp.shape.eval()
+            == numpy_rv_logp.shape.eval()
+        )
+
+        npt.assert_allclose(latent_rv_logp.eval(), marginal_rv_logp.eval(), atol=5)
+        npt.assert_allclose(latent_rv_logp.eval(), numpy_rv_logp.eval(), atol=5)
+        npt.assert_allclose(marginal_rv_logp.eval(), numpy_rv_logp.eval(), atol=5)
+
 
 class TestTP:
     R"""
@@ -486,7 +547,11 @@ def setup_method(self):
         self.X = cartesian(*self.Xs)
         self.N = np.prod([len(X) for X in self.Xs])
         self.y = np.random.randn(self.N) * 0.1
-        self.Xnews = (np.random.randn(5, 1), np.random.randn(5, 1), np.random.randn(5, 1))
+        self.Xnews = (
+            np.random.randn(5, 1),
+            np.random.randn(5, 1),
+            np.random.randn(5, 1),
+        )
         self.Xnew = np.concatenate(self.Xnews, axis=1)
         self.sigma = 0.2
         self.pnew = np.random.randn(len(self.Xnew))

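The NumPy reference inside testLatentMultioutput is the standard GP conditional applied to all output rows at once. In the diff's notation, with L the Cholesky factor of K(X, X) and m the mean function, the quantities the test rebuilds are:

    A = L^{-1} K(X, X_*), \qquad v = L^{-1} \left( y - m(X) \right)^{\top},
    \mu_* = m(X_*) + \left( A^{\top} v \right)^{\top}, \qquad
    \Sigma_* = K(X_*, X_*) - A^{\top} A

The extra transposes around v and \mu_* are exactly what the .T changes in _build_conditional implement: each of the n_outputs rows of y is conditioned with the same Cholesky solve, while the shared covariance \Sigma_* is unchanged.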
tests/gp/test_hsgp_approx.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first):
         gp = pm.gp.HSGP(m=[n_basis], c=4.0, cov_func=cov_func, drop_first=drop_first)
         gp.prior("f1", X1)
 
-        n_coeffs = model.f1_hsgp_coeffs_.type.shape[0]
+        n_coeffs = model.f1_hsgp_coeffs.type.shape[0]
         if drop_first:
             assert (
                 n_coeffs == n_basis - 1

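The one-line test_hsgp_approx.py change just tracks the variable rename. For completeness, a hedged sketch of the behavior the updated assertion checks (the cov_func and n_basis values here are assumptions, not from the commit):

    import numpy as np
    import pymc as pm

    X1 = np.linspace(0, 10, 50)[:, None]
    n_basis = 25

    with pm.Model() as model:
        cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)
        gp = pm.gp.HSGP(m=[n_basis], c=4.0, cov_func=cov_func, drop_first=True)
        gp.prior("f1", X1)

    # drop_first=True removes one basis vector, so the renamed
    # coefficient variable has n_basis - 1 entries.
    n_coeffs = model.f1_hsgp_coeffs.type.shape[0]
    assert n_coeffs == n_basis - 1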