Skip to content

Commit 2df8082

Browse files
committed
Add test for batched marginalized variables
1 parent 3dfaea8 commit 2df8082

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

pymc_experimental/model/marginal_model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,14 +368,14 @@ def transform_input(inputs):
368368

369369
rv_shape = constant_fold(tuple(rv.shape))
370370
rv_domain = get_domain_of_finite_discrete_rv(rv)
371-
rv_domain_tensor = pt.swapaxes(
371+
rv_domain_tensor = pt.moveaxis(
372372
pt.full(
373373
(*rv_shape, len(rv_domain)),
374374
rv_domain,
375375
dtype=rv.dtype,
376376
),
377-
axis1=0,
378-
axis2=-1,
377+
-1,
378+
0,
379379
)
380380

381381
marginalized_value = m.rvs_to_values[rv]
@@ -641,14 +641,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
641641
# PyMC does not allow RVs in the logp graph, even if we are just using the shape
642642
marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
643643
marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
644-
marginalized_rv_domain_tensor = pt.swapaxes(
644+
marginalized_rv_domain_tensor = pt.moveaxis(
645645
pt.full(
646646
(*marginalized_rv_shape, len(marginalized_rv_domain)),
647647
marginalized_rv_domain,
648648
dtype=marginalized_rv.dtype,
649649
),
650-
axis1=0,
651-
axis2=-1,
650+
-1,
651+
0,
652652
)
653653

654654
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control

pymc_experimental/tests/model/test_marginal_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,30 @@ def true_logp(y, sigma):
297297
)
298298

299299

300+
def test_recover_batched_marginal():
301+
"""Test that marginalization works for batched random variables"""
302+
with MarginalModel() as m:
303+
sigma = pm.HalfNormal("sigma")
304+
idx = pm.Bernoulli("idx", p=0.7, shape=(2, 2))
305+
y = pm.Normal("y", mu=idx, sigma=sigma, shape=(2, 2))
306+
307+
m.marginalize([idx])
308+
309+
rng = np.random.default_rng(211)
310+
311+
with m:
312+
prior = pm.sample_prior_predictive(
313+
samples=20,
314+
random_seed=rng,
315+
return_inferencedata=False,
316+
)
317+
idata = InferenceData(
318+
posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior})
319+
)
320+
321+
idata = m.recover_marginals(idata, return_samples=True)
322+
323+
300324
def test_nested_recover_marginals():
301325
"""Test that marginalization works when there are nested marginalized RVs"""
302326

0 commit comments

Comments
 (0)