12
12
from pymc .logprob .abstract import _logprob
13
13
from pymc .util import UNSET
14
14
from scipy .special import logsumexp
15
- from scipy .stats import norm
15
+ from scipy .stats import halfnorm , norm
16
16
17
17
from pymc_experimental .model .marginal_model import (
18
18
FiniteDiscreteMarginalRV ,
@@ -255,11 +255,12 @@ def test_marginalized_change_point_model_sampling(disaster_model):
255
255
256
256
def test_recover_marginals_basic ():
257
257
with MarginalModel () as m :
258
+ sigma = pm .HalfNormal ("sigma" , transform = None )
258
259
p = np .array ([0.5 , 0.2 , 0.3 ])
259
260
k = pm .Categorical ("k" , p = p )
260
261
mu = np .array ([- 3.0 , 0.0 , 3.0 ])
261
262
mu_ = pt .as_tensor_variable (mu )
262
- y = pm .Normal ("y" , mu = mu_ [k ])
263
+ y = pm .Normal ("y" , mu = mu_ [k ], sigma = sigma )
263
264
264
265
m .marginalize ([k ])
265
266
@@ -279,17 +280,17 @@ def test_recover_marginals_basic():
279
280
assert idata .k .shape == idata .y .shape
280
281
assert idata .lp_k .shape == idata .k .shape + (len (p ),)
281
282
282
- def true_logp (y ):
283
+ def true_logp (y , sigma ):
283
284
y = y .repeat (len (p )).reshape (len (y ), - 1 )
284
- return np .log (p ) + norm .logpdf (y , loc = mu )
285
+ sigma = sigma .repeat (len (p )).reshape (len (sigma ), - 1 )
286
+ return np .log (p ) + norm .logpdf (y , loc = mu , scale = sigma ) + halfnorm .logpdf (sigma )
285
287
286
288
np .testing .assert_almost_equal (
287
- true_logp (idata .y .values .flatten ()),
289
+ true_logp (idata .y .values .flatten (), idata . sigma . values . flatten () ),
288
290
idata .lp_k [0 ].values ,
289
291
)
290
292
291
293
292
- @pytest .mark .filterwarnings ("error" )
293
294
def test_nested_recover_marginals ():
294
295
"""Test that marginalization works when there are nested marginalized RVs"""
295
296
0 commit comments