@@ -148,18 +148,37 @@ class Latent(Base):
     def __init__(self, *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
-    def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
+    def _build_prior(
+        self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs
+    ):
         mu = self.mean_func(X)
         cov = stabilize(self.cov_func(X), jitter)
         if reparameterize:
-            size = np.shape(X)[0]
-            v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
-            f = pm.Deterministic(name, mu + cholesky(cov).dot(v), dims=kwargs.get("dims", None))
+            if "dims" in kwargs:
+                v = pm.Normal(
+                    name + "_rotated_",
+                    mu=0.0,
+                    sigma=1.0,
+                    **kwargs,
+                )
+
+            else:
+                size = (n_outputs, X.shape[0]) if n_outputs > 1 else X.shape[0]
+                v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
+
+            f = pm.Deterministic(
+                name,
+                mu + cholesky(cov).dot(v.T).transpose(),
+                dims=kwargs.get("dims", None),
+            )
+
         else:
-            f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+            mu_stack = pt.stack([mu] * n_outputs, axis=0) if n_outputs > 1 else mu
+            f = pm.MvNormal(name, mu=mu_stack, cov=cov, **kwargs)
+
         return f
 
-    def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
+    def prior(self, name, X, n_outputs=1, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
         R"""
         Returns the GP prior distribution evaluated over the input
         locations `X`.
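Aside, not part of the diff: a minimal NumPy sketch of the shape logic behind the new batched rotation. Assuming `v` has shape `(n_outputs, n)` and `L` is the `(n, n)` Cholesky factor of the covariance, `mu + L.dot(v.T).T` rotates each row of `v` and broadcasts the mean across outputs:

    import numpy as np

    n_outputs, n = 3, 10
    L = np.linalg.cholesky(np.eye(n))   # stands in for cholesky(cov)
    mu = np.zeros(n)                    # stands in for self.mean_func(X)
    v = np.random.randn(n_outputs, n)   # one standard-normal row per output GP
    f = mu + L.dot(v.T).T               # L.dot(v.T) is (n, n_outputs); transpose back
    assert f.shape == (n_outputs, n)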
@@ -178,6 +197,12 @@ def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
         X : array-like
             Function input values. If one-dimensional, must be a column
             vector with shape `(n, 1)`.
+        n_outputs : int, default 1
+            Number of output GPs. If you're using `dims`, make sure their size
+            is equal to `(n_outputs, X.shape[0])`, i.e. the number of output GPs
+            by the number of input points.
+            Example: `gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))`,
+            where `len(n_gps) = 3` and `len(x_dim) = X.shape[0]`.
         reparameterize : bool, default True
             Reparameterize the distribution by rotating the random
             variable by the Cholesky factor of the covariance matrix.
@@ -188,10 +213,12 @@ def prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
             Extra keyword arguments that are passed to :class:`~pymc.MvNormal`
             distribution constructor.
         """
+        f = self._build_prior(name, X, n_outputs, reparameterize, jitter, **kwargs)
 
-        f = self._build_prior(name, X, reparameterize, jitter, **kwargs)
         self.X = X
         self.f = f
+        self.n_outputs = n_outputs
+
         return f
 
     def _get_given_vals(self, given):
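For illustration only, a hedged usage sketch of the new parameter; the `dims` come from the docstring's own example, while the covariance choice, lengthscale, and coord values here are assumptions:

    import numpy as np
    import pymc as pm

    X = np.linspace(0, 1, 50)[:, None]
    coords = {"n_gps": np.arange(3), "x_dim": np.arange(X.shape[0])}
    with pm.Model(coords=coords):
        gp = pm.gp.Latent(cov_func=pm.gp.cov.ExpQuad(1, ls=0.2))
        # three latent GPs over the same inputs; f has shape (n_outputs, n)
        f = gp.prior("f", X=X, n_outputs=3, dims=("n_gps", "x_dim"))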
@@ -212,12 +239,16 @@ def _get_given_vals(self, given):
     def _build_conditional(self, Xnew, X, f, cov_total, mean_total, jitter):
         Kxx = cov_total(X)
         Kxs = self.cov_func(X, Xnew)
+
         L = cholesky(stabilize(Kxx, jitter))
         A = solve_lower(L, Kxs)
-        v = solve_lower(L, f - mean_total(X))
-        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
+        v = solve_lower(L, (f - mean_total(X)).T)
+
+        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T
+
         Kss = self.cov_func(Xnew)
         cov = Kss - pt.dot(pt.transpose(A), A)
+
         return mu, cov
 
     def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
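The transposes above keep the single- and multi-output cases on one code path: a batched `f` of shape `(n_outputs, n)` is flipped to `(n, n_outputs)` so the triangular solve runs column-wise, and the predictive mean is flipped back. A NumPy sketch of just the shapes, not from the diff (a plain solve stands in for `solve_lower`):

    import numpy as np

    n, m, n_outputs = 10, 5, 3
    L = np.linalg.cholesky(np.eye(n))        # cholesky(stabilize(Kxx, jitter))
    A = np.linalg.solve(L, np.ones((n, m)))  # solve_lower(L, Kxs) -> (n, m)
    resid = np.random.randn(n_outputs, n)    # f - mean_total(X), one row per output
    v = np.linalg.solve(L, resid.T)          # (n, n_outputs)
    mu = (A.T @ v).T                         # (m, n_outputs) -> (n_outputs, m)
    assert mu.shape == (n_outputs, m)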
@@ -255,7 +286,9 @@ def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
         """
         givens = self._get_given_vals(given)
         mu, cov = self._build_conditional(Xnew, *givens, jitter)
-        return pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+        f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
+
+        return f
 
 
 @conditioned_vars(["X", "f", "nu"])
@@ -447,7 +480,15 @@ def _build_marginal_likelihood(self, X, noise_func, jitter):
         return mu, stabilize(cov, jitter)
 
     def marginal_likelihood(
-        self, name, X, y, sigma=None, noise=None, jitter=JITTER_DEFAULT, is_observed=True, **kwargs
+        self,
+        name,
+        X,
+        y,
+        sigma=None,
+        noise=None,
+        jitter=JITTER_DEFAULT,
+        is_observed=True,
+        **kwargs,
     ):
         R"""
         Returns the marginal likelihood distribution, given the input
@@ -529,21 +570,28 @@ def _build_conditional(
         Kxs = self.cov_func(X, Xnew)
         Knx = noise_func(X)
         rxx = y - mean_total(X)
+
         L = cholesky(stabilize(Kxx, jitter) + Knx)
         A = solve_lower(L, Kxs)
-        v = solve_lower(L, rxx)
-        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v)
+        v = solve_lower(L, rxx.T)
+        mu = self.mean_func(Xnew) + pt.dot(pt.transpose(A), v).T
+
         if diag:
             Kss = self.cov_func(Xnew, diag=True)
             var = Kss - pt.sum(pt.square(A), 0)
+
             if pred_noise:
                 var += noise_func(Xnew, diag=True)
+
             return mu, var
+
         else:
             Kss = self.cov_func(Xnew)
             cov = Kss - pt.dot(pt.transpose(A), A)
+
             if pred_noise:
                 cov += noise_func(Xnew)
+
             return mu, cov if pred_noise else stabilize(cov, jitter)
 
     def conditional(
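`Marginal._build_conditional` applies the same residual transpose as the `Latent` version above. As a hedged usage sketch of the predictive path these changes touch (the data, lengthscale, and noise level are placeholder assumptions):

    import numpy as np
    import pymc as pm

    X = np.linspace(0, 1, 30)[:, None]
    y = np.sin(3 * X).flatten()
    Xnew = np.linspace(0, 1.2, 40)[:, None]

    with pm.Model():
        gp = pm.gp.Marginal(cov_func=pm.gp.cov.ExpQuad(1, ls=0.2))
        y_ = gp.marginal_likelihood("y", X=X, y=y, sigma=0.1)
        # pred_noise=True adds observation noise to the predictive covariance
        f_star = gp.conditional("f_star", Xnew=Xnew, pred_noise=True)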