Skip to content

Commit 31150c3

Browse files
bwengalsbwengals
authored and
bwengals
committed
fix TP reparameterization to sample studentt instead of chi2/norm
1 parent d1e09fe commit 31150c3

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

pymc/gp/gp.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pymc.gp.util import (
2727
cholesky,
2828
conditioned_vars,
29-
infer_shape,
29+
infer_size,
3030
replace_with_value,
3131
solve_lower,
3232
solve_upper,
@@ -129,12 +129,13 @@ def __init__(self, *, mean_func=Zero(), cov_func=Constant(0.0)):
129129
def _build_prior(self, name, X, reparameterize=True, **kwargs):
130130
mu = self.mean_func(X)
131131
cov = stabilize(self.cov_func(X))
132-
shape = infer_shape(X, kwargs.pop("shape", None))
133132
if reparameterize:
134-
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=shape, **kwargs)
133+
size = infer_size(X, kwargs.pop("size", None))
134+
print("_build_prior:size", size)
135+
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
135136
f = pm.Deterministic(name, mu + cholesky(cov).dot(v))
136137
else:
137-
f = pm.MvNormal(name, mu=mu, cov=cov, size=shape, **kwargs)
138+
f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
138139
return f
139140

140141
def prior(self, name, X, reparameterize=True, **kwargs):
@@ -269,13 +270,12 @@ def __add__(self, other):
269270
def _build_prior(self, name, X, reparameterize=True, **kwargs):
270271
mu = self.mean_func(X)
271272
cov = stabilize(self.cov_func(X))
272-
shape = infer_shape(X, kwargs.pop("shape", None))
273273
if reparameterize:
274-
chi2 = pm.ChiSquared(name + "_chi2_", self.nu)
275-
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=shape, **kwargs)
276-
f = pm.Deterministic(name, (at.sqrt(self.nu) / chi2) * (mu + cholesky(cov).dot(v)))
274+
size = infer_size(X, kwargs.pop("size", None))
275+
v = pm.StudentT(name + "_rotated_", mu=0.0, sigma=1.0, nu=self.nu, size=size, **kwargs)
276+
f = pm.Deterministic(name, mu + cholesky(cov).dot(v))
277277
else:
278-
f = pm.MvStudentT(name, nu=self.nu, mu=mu, cov=cov, size=shape, **kwargs)
278+
f = pm.MvStudentT(name, nu=self.nu, mu=mu, cov=cov, **kwargs)
279279
return f
280280

281281
def prior(self, name, X, reparameterize=True, **kwargs):
@@ -436,7 +436,7 @@ def marginal_likelihood(self, name, X, y, noise, is_observed=True, **kwargs):
436436
if is_observed:
437437
return pm.MvNormal(name, mu=mu, cov=cov, observed=y, **kwargs)
438438
else:
439-
# shape = infer_shape(X, kwargs.pop("shape", None))
439+
# size = infer_size(X, kwargs.pop("size", None))
440440
return pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
441441

442442
def _get_given_vals(self, given):
@@ -974,8 +974,8 @@ def conditional(self, name, Xnew, **kwargs):
974974
constructor.
975975
"""
976976
mu, cov = self._build_conditional(Xnew)
977-
shape = infer_shape(Xnew, kwargs.pop("shape", None))
978-
return pm.MvNormal(name, mu=mu, cov=cov, size=shape, **kwargs)
977+
size = infer_size(Xnew, kwargs.pop("size", None))
978+
return pm.MvNormal(name, mu=mu, cov=cov, size=size, **kwargs)
979979

980980

981981
@conditioned_vars(["Xs", "y", "sigma"])
@@ -1098,8 +1098,8 @@ def marginal_likelihood(self, name, Xs, y, sigma, is_observed=True, **kwargs):
10981098
if is_observed:
10991099
return pm.KroneckerNormal(name, mu=mu, covs=covs, sigma=sigma, observed=y, **kwargs)
11001100
else:
1101-
shape = np.prod([len(X) for X in Xs])
1102-
return pm.KroneckerNormal(name, mu=mu, covs=covs, sigma=sigma, size=shape, **kwargs)
1101+
size = np.prod([len(X) for X in Xs])
1102+
return pm.KroneckerNormal(name, mu=mu, covs=covs, sigma=sigma, size=size, **kwargs)
11031103

11041104
def _build_conditional(self, Xnew, pred_noise, diag):
11051105
Xs, y, sigma = self.Xs, self.y, self.sigma
@@ -1175,8 +1175,8 @@ def conditional(self, name, Xnew, pred_noise=False, **kwargs):
11751175
constructor.
11761176
"""
11771177
mu, cov = self._build_conditional(Xnew, pred_noise, False)
1178-
shape = infer_shape(Xnew, kwargs.pop("shape", None))
1179-
return pm.MvNormal(name, mu=mu, cov=cov, size=shape, **kwargs)
1178+
size = infer_size(Xnew, kwargs.pop("size", None))
1179+
return pm.MvNormal(name, mu=mu, cov=cov, size=size, **kwargs)
11801180

11811181
def predict(self, Xnew, point=None, diag=False, pred_noise=False):
11821182
R"""

0 commit comments

Comments
 (0)