
Commit 37230d3

Chris Fonnesbeck committed
Merge branch 'main' into auto_versioning

2 parents b9f2adf + c858f0f

File tree

7 files changed: +137 −147 lines changed

pymc/distributions/multivariate.py
Lines changed: 8 additions & 16 deletions

@@ -2200,32 +2200,23 @@ def make_node(self, rng, size, dtype, alpha, K):
         alpha = at.as_tensor_variable(alpha)
         K = at.as_tensor_variable(intX(K))
 
-        if alpha.ndim > 0:
-            raise ValueError("The concentration parameter needs to be a scalar.")
-
         if K.ndim > 0:
             raise ValueError("K must be a scalar.")
 
         return super().make_node(rng, size, dtype, alpha, K)
 
-    def _infer_shape(self, size, dist_params, param_shapes=None):
-        alpha, K = dist_params
-
-        size = tuple(size)
-
-        return size + (K + 1,)
+    def _supp_shape_from_params(self, dist_params, **kwargs):
+        K = dist_params[1]
+        return (K + 1,)
 
     @classmethod
     def rng_fn(cls, rng, alpha, K, size):
         if K < 0:
             raise ValueError("K needs to be positive.")
 
-        if size is None:
-            size = (K,)
-        elif isinstance(size, int):
-            size = (size,) + (K,)
-        else:
-            size = tuple(size) + (K,)
+        size = to_tuple(size) if size is not None else alpha.shape
+        size = size + (K,)
+        alpha = alpha[..., np.newaxis]
 
         betas = rng.beta(1, alpha, size=size)

@@ -2294,9 +2285,10 @@ def dist(cls, alpha, K, *args, **kwargs):
         return super().dist([alpha, K], **kwargs)
 
     def moment(rv, size, alpha, K):
+        alpha = alpha[..., np.newaxis]
         moment = (alpha / (1 + alpha)) ** at.arange(K)
         moment *= 1 / (1 + alpha)
-        moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
+        moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1)
         if not rv_size_is_none(size):
             moment_size = at.concatenate(
                 [
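Taken together, these changes let StickBreakingWeights accept a batched concentration parameter: _supp_shape_from_params reports the support size K + 1, and rng_fn now broadcasts alpha against the requested size. A minimal sketch of the resulting behaviour (the values are illustrative; the shapes match the new random-draw tests further down):

import numpy as np
import pymc as pm

# A vector of concentrations yields one simplex of K + 1 weights per entry.
alpha = np.array([1.0, 2.0, 3.0])
draws = pm.StickBreakingWeights.dist(alpha=alpha, K=19).eval()

print(draws.shape)         # (3, 20), i.e. alpha.shape + (K + 1,)
print(draws.sum(axis=-1))  # each row sums to 1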

pymc/gp/gp.py
Lines changed: 4 additions & 44 deletions

@@ -685,18 +685,13 @@ def __init__(self, approx="VFE", *, mean_func=Zero(), cov_func=Constant(0.0)):
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
     def __add__(self, other):
-        # new_gp will default to FITC approx
         new_gp = super().__add__(other)
-        # make sure new gp has correct approx
         if not self.approx == other.approx:
             raise TypeError("Cannot add GPs with different approximations")
         new_gp.approx = self.approx
         return new_gp
 
-    # Use y as first argument, so that we can use functools.partial
-    # in marginal_likelihood instead of lambda. This makes pickling
-    # possible.
-    def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
+    def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
         sigma2 = at.square(sigma)
         Kuu = self.cov_func(Xu)
         Kuf = self.cov_func(Xu, X)

@@ -725,9 +720,7 @@ def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
         quadratic = 0.5 * (at.dot(r, r_l) - at.dot(c, c))
         return -1.0 * (constant + logdet + quadratic + trace)
 
-    def marginal_likelihood(
-        self, name, X, Xu, y, noise=None, is_observed=True, jitter=JITTER_DEFAULT, **kwargs
-    ):
+    def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT, **kwargs):
         R"""
         Returns the approximate marginal likelihood distribution, given the input
         locations `X`, inducing point locations `Xu`, data `y`, and white noise

@@ -747,9 +740,6 @@ def marginal_likelihood(
             noise. Must have shape `(n, )`.
         noise: scalar, Variable
             Standard deviation of the Gaussian noise.
-        is_observed: bool
-            Whether to set `y` as an `observed` variable in the `model`.
-            Default is `True`.
         jitter: scalar
             A small correction added to the diagonal of positive semi-definite
             covariance matrices to ensure numerical stability.

@@ -767,38 +757,8 @@ def marginal_likelihood(
         else:
             self.sigma = noise
 
-        if is_observed:
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                size=X.shape[0],
-                **kwargs,
-            )
-        else:
-            warnings.warn(
-                "The 'is_observed' argument has been deprecated. If the GP is "
-                "unobserved use gp.Latent instead.",
-                FutureWarning,
-            )
-            return pm.DensityDist(
-                name,
-                X,
-                Xu,
-                self.sigma,
-                jitter,
-                logp=self._build_marginal_likelihood_logp,
-                observed=y,
-                ndims_params=[2, 2, 0],
-                # ndim_supp=1,
-                size=X.shape[0],
-                **kwargs,
-            )
+        approx_loglik = self._build_marginal_likelihood_loglik(y, X, Xu, noise, jitter)
+        pm.Potential(f"marginalapprox_loglik_{name}", approx_loglik, **kwargs)
 
     def _build_conditional(
         self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter
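With this change, MarginalApprox.marginal_likelihood registers the approximate log-likelihood as a pm.Potential instead of wrapping it in a pm.DensityDist, so the method no longer returns a random variable and the deprecated is_observed branch is gone. A usage sketch under those assumptions (data, kernel, and priors below are illustrative):

import numpy as np
import pymc as pm

X = np.linspace(0, 10, 100)[:, None]  # input locations
Xu = np.linspace(0, 10, 20)[:, None]  # inducing point locations
y = np.sin(X.ravel()) + 0.1 * np.random.randn(100)

with pm.Model() as model:
    cov_func = pm.gp.cov.ExpQuad(1, ls=1.0)
    gp = pm.gp.MarginalApprox(cov_func=cov_func, approx="FITC")
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    # Adds a Potential named "marginalapprox_loglik_lik" to the model;
    # nothing is returned, unlike the old DensityDist-based version.
    gp.marginal_likelihood("lik", X=X, Xu=Xu, y=y, noise=sigma)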

pymc/tests/test_distributions.py
Lines changed: 31 additions & 0 deletions

@@ -24,6 +24,7 @@
 from aeppl.logprob import ParameterValueError
 from aesara.tensor.random.utils import broadcast_params
 
+from pymc.aesaraf import compile_pymc
 from pymc.distributions.continuous import get_tau_sigma
 from pymc.util import UNSET

@@ -953,6 +954,17 @@ def test_hierarchical_obs_logp():
     assert not any(isinstance(o, RandomVariable) for o in ops)
 
 
+@pytest.fixture(scope="module")
+def stickbreakingweights_logpdf():
+    _value = at.vector()
+    _alpha = at.scalar()
+    _k = at.iscalar()
+    _logp = logp(StickBreakingWeights.dist(_alpha, _k), _value)
+    core_fn = compile_pymc([_value, _alpha, _k], _logp)
+
+    return np.vectorize(core_fn, signature="(n),(),()->()")
+
+
 class TestMatchesScipy:
     def test_uniform(self):
         check_logp(

@@ -2318,6 +2330,25 @@ def test_stickbreakingweights_invalid(self):
         assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
         assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
 
+    @pytest.mark.parametrize(
+        "alpha,K",
+        [
+            (np.array([0.5, 1.0, 2.0]), 3),
+            (np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
+        ],
+    )
+    def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
+        value = pm.StickBreakingWeights.dist(alpha, K).eval()
+        with Model():
+            sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
+        pt = {"sbw": value}
+        assert_almost_equal(
+            pm.logp(sbw, value).eval(),
+            stickbreakingweights_logpdf(value, alpha, K),
+            decimal=select_by_precision(float64=6, float32=2),
+            err_msg=str(pt),
+        )
+
     @aesara.config.change_flags(compute_test_value="raise")
     def test_categorical_bounds(self):
         with Model():
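The fixture above compiles the logp once for scalar parameters, then uses np.vectorize with a gufunc signature to broadcast that core function over batched alpha; the parametrized test compares the vectorized distribution logp against this reference. A standalone illustration of the pattern (the toy core_fn below is hypothetical):

import numpy as np

def core_fn(value, alpha, k):
    # Stand-in for the compiled logp: value is a length-n vector,
    # alpha and k are scalars, and the result is a scalar.
    return alpha * value.sum() + k

vec_fn = np.vectorize(core_fn, signature="(n),(),()->()")

values = np.ones((2, 3, 4))             # a (2, 3) batch of length-4 vectors
alphas = np.arange(6.0).reshape(2, 3)   # one alpha per batch element
print(vec_fn(values, alphas, 5).shape)  # (2, 3): broadcast over leading dims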

pymc/tests/test_distributions_moments.py
Lines changed: 26 additions & 0 deletions

@@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected):
                 fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
             ),
         ),
+        (
+            np.array([1, 3]),
+            11,
+            None,
+            np.array(
+                [
+                    np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11),
+                    np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11),
+                ]
+            ),
+        ),
+        (
+            np.array([1, 3, 5]),
+            9,
+            (5, 3),
+            np.full(
+                shape=(5, 3, 10),
+                fill_value=np.array(
+                    [
+                        np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9),
+                        np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9),
+                        np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9),
+                    ]
+                ),
+            ),
+        ),
     ],
 )
 def test_stickbreakingweights_moment(alpha, K, size, expected):
pymc/tests/test_distributions_random.py
Lines changed: 12 additions & 0 deletions

@@ -1329,6 +1329,18 @@ def check_basic_properties(self):
         assert np.all(draws <= 1)
 
 
+class TestStickBreakingWeights_1D_alpha(BaseTestDistributionRandom):
+    pymc_dist = pm.StickBreakingWeights
+    pymc_dist_params = {"alpha": [1.0, 2.0, 3.0], "K": 19}
+    expected_rv_op_params = {"alpha": [1.0, 2.0, 3.0], "K": 19}
+    sizes_to_check = [None, (3,), (5, 3)]
+    sizes_expected = [(3, 20), (3, 20), (5, 3, 20)]
+    checks_to_run = [
+        "check_pymc_params_match_rv_op",
+        "check_rv_size",
+    ]
+
+
 class TestCategorical(BaseTestDistributionRandom):
     pymc_dist = pm.Categorical
     pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])}

pymc/tests/test_gp.py
Lines changed: 56 additions & 48 deletions

@@ -846,63 +846,71 @@ def testLatent2(self):
 
 class TestMarginalVsMarginalApprox:
     R"""
-    Compare logp of models Marginal and MarginalApprox.
-    Should be nearly equal when inducing points are same as inputs.
+    Compare test fits of models Marginal and MarginalApprox.
     """
 
     def setup_method(self):
-        X = np.random.randn(50, 3)
-        y = np.random.randn(50)
-        Xnew = np.random.randn(60, 3)
-        pnew = np.random.randn(60)
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
-            sigma = 0.1
-            f = gp.marginal_likelihood("f", X, y, noise=sigma)
-            p = gp.conditional("p", Xnew)
-        self.logp = model.compile_logp()({"p": pnew})
-        self.X = X
-        self.Xnew = Xnew
-        self.y = y
-        self.sigma = sigma
-        self.pnew = pnew
-        self.gp = gp
+        self.sigma = 0.1
+        self.x = np.linspace(-5, 5, 30)
+        self.y = np.random.normal(0.25 * self.x, self.sigma)
+        with pm.Model() as model:
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0)  # far from true value
+            mean_func = pm.gp.mean.Constant(c)
+            self.gp = pm.gp.Marginal(mean_func=mean_func, cov_func=cov_func)
+            sigma = pm.HalfNormal("sigma", sigma=100)
+            self.gp.marginal_likelihood("lik", self.x[:, None], self.y, sigma)
+            self.map_full = pm.find_MAP(method="bfgs")  # bfgs seems to work much better than lbfgsb
+
+        self.x_new = np.linspace(-6, 6, 20)
+
+        # Include additive Gaussian noise, return diagonal of predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_var = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=True, diag=True
+            )
 
-    @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testApproximations(self, approx):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            p = gp.conditional("p", self.Xnew)
-            approx_logp = model.compile_logp()({"p": self.pnew})
-        npt.assert_allclose(approx_logp, self.logp, atol=0, rtol=1e-2)
+        # Dont include additive Gaussian noise, return full predicted covariance matrix
+        with model:
+            self.pred_mu, self.pred_covar = self.gp.predict(
+                self.x_new[:, None], point=self.map_full, pred_noise=False, diag=False
+            )
 
     @pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
-    def testPredictVar(self, approx):
+    def test_fits_and_preds(self, approx):
+        """Get MAP estimate for GP approximation, compare results and predictions to what's returned
+        by an unapproximated GP. The tolerances are fairly wide, but narrow relative to initial
+        values of the unknown parameters.
+        """
+
         with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
+            cov_func = pm.gp.cov.Linear(1, c=0.0)
+            c = pm.Normal("c", mu=20.0, sigma=100.0, initval=-500.0)
+            mean_func = pm.gp.mean.Constant(c)
             gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, var1 = self.gp.predict(self.Xnew, diag=True)
-            mu2, var2 = gp.predict(self.Xnew, diag=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(var1, var2, atol=0, rtol=1e-3)
+            sigma = pm.HalfNormal("sigma", sigma=100, initval=50.0)
+            gp.marginal_likelihood("lik", self.x[:, None], self.x[:, None], self.y, sigma)
+            map_approx = pm.find_MAP(method="bfgs")
+
+        # Check MAP gets approximately correct result
+        npt.assert_allclose(self.map_full["c"], map_approx["c"], atol=0.01, rtol=0.1)
+        npt.assert_allclose(self.map_full["sigma"], map_approx["sigma"], atol=0.01, rtol=0.1)
+
+        # Check that predict (and conditional) work, include noise, with diagonal non-full pred var.
+        with model:
+            pred_mu_approx, pred_var_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)
 
-    def testPredictCov(self):
-        with pm.Model() as model:
-            cov_func = pm.gp.cov.ExpQuad(3, [0.1, 0.2, 0.3])
-            mean_func = pm.gp.mean.Constant(0.5)
-            gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx="DTC")
-            f = gp.marginal_likelihood("f", self.X, self.X, self.y, self.sigma)
-            mu1, cov1 = self.gp.predict(self.Xnew, pred_noise=True)
-            mu2, cov2 = gp.predict(self.Xnew, pred_noise=True)
-        npt.assert_allclose(mu1, mu2, atol=0, rtol=1e-3)
-        npt.assert_allclose(cov1, cov2, atol=0, rtol=1e-3)
+        # Check that predict (and conditional) work, no noise, full pred covariance.
+        with model:
+            pred_mu_approx, pred_var_approx = gp.predict(
+                self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
+            )
+        npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
+        npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)
 
 
 class TestGPAdditive: