Skip to content

Commit 28dfb0f

Browse files
committed
First crack at transforms
1 parent e4c9db7 commit 28dfb0f

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

pymc_experimental/tests/model/test_marginal_model.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pymc.logprob.abstract import _logprob
1313
from pymc.util import UNSET
1414
from scipy.special import logsumexp
15-
from scipy.stats import norm
15+
from scipy.stats import halfnorm, norm
1616

1717
from pymc_experimental.model.marginal_model import (
1818
FiniteDiscreteMarginalRV,
@@ -255,11 +255,12 @@ def test_marginalized_change_point_model_sampling(disaster_model):
255255

256256
def test_recover_marginals_basic():
257257
with MarginalModel() as m:
258+
sigma = pm.HalfNormal("sigma", transform=None)
258259
p = np.array([0.5, 0.2, 0.3])
259260
k = pm.Categorical("k", p=p)
260261
mu = np.array([-3.0, 0.0, 3.0])
261262
mu_ = pt.as_tensor_variable(mu)
262-
y = pm.Normal("y", mu=mu_[k])
263+
y = pm.Normal("y", mu=mu_[k], sigma=sigma)
263264

264265
m.marginalize([k])
265266

@@ -279,17 +280,17 @@ def test_recover_marginals_basic():
279280
assert idata.k.shape == idata.y.shape
280281
assert idata.lp_k.shape == idata.k.shape + (len(p),)
281282

282-
def true_logp(y):
283+
def true_logp(y, sigma):
283284
y = y.repeat(len(p)).reshape(len(y), -1)
284-
return np.log(p) + norm.logpdf(y, loc=mu)
285+
sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
286+
return np.log(p) + norm.logpdf(y, loc=mu, scale=sigma) + halfnorm.logpdf(sigma)
285287

286288
np.testing.assert_almost_equal(
287-
true_logp(idata.y.values.flatten()),
289+
true_logp(idata.y.values.flatten(), idata.sigma.values.flatten()),
288290
idata.lp_k[0].values,
289291
)
290292

291293

292-
@pytest.mark.filterwarnings("error")
293294
def test_nested_recover_marginals():
294295
"""Test that marginalization works when there are nested marginalized RVs"""
295296

0 commit comments

Comments
 (0)