Skip to content

Commit 9d9daa4

Browse files
committed
Add logic for dealing with batched dims
1 parent c17ba4a commit 9d9daa4

File tree

2 files changed

+77
-43
lines changed

2 files changed

+77
-43
lines changed

pymc_experimental/model/marginal_model.py

+50-23
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pymc.logprob.transforms import IntervalTransform
1515
from pymc.model import Model
1616
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
17-
from pymc.util import dataset_to_point_list, treedict
17+
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
1818
from pytensor import Mode
1919
from pytensor.compile import SharedVariable
2020
from pytensor.compile.builders import OpFromGraph
@@ -233,9 +233,14 @@ def clone(self):
233233
m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
234234
return m
235235

236-
def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
236+
def marginalize(
237+
self, rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]]
238+
):
237239
if not isinstance(rvs_to_marginalize, Sequence):
238240
rvs_to_marginalize = (rvs_to_marginalize,)
241+
rvs_to_marginalize = [
242+
self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
243+
]
239244

240245
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
241246
for rv_to_marginalize in rvs_to_marginalize:
@@ -279,7 +284,12 @@ def unmarginalize(self, rvs_to_unmarginalize):
279284
self.register_rv(rv, name=rv.name)
280285

281286
def recover_marginals(
282-
self, idata, var_names=None, return_samples=True, extend_inferencedata=True
287+
self,
288+
idata,
289+
var_names=None,
290+
return_samples=True,
291+
extend_inferencedata=True,
292+
random_seed=None,
283293
):
284294
"""Computes posterior log-probabilities and samples of marginalized variables
285295
conditioned on parameters of the model given InferenceData with posterior group
@@ -299,6 +309,8 @@ def recover_marginals(
299309
If True, also return samples of the marginalized variables
300310
extend_inferencedata : bool, default True
301311
Whether to extend the original InferenceData or return a new one
312+
random_seed: int, array-like of int or SeedSequence, optional
313+
Seed used to generate samples
302314
303315
Returns
304316
-------
@@ -323,15 +335,19 @@ def recover_marginals(
323335
324336
"""
325337
if var_names is None:
326-
var_names = {var.name for var in self.marginalized_rvs}
327-
else:
328-
var_names = {var_names}
338+
var_names = [var.name for var in self.marginalized_rvs]
329339

340+
var_names = [var if isinstance(var, str) else var.name for var in var_names]
330341
vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
331-
missing_names = var_names.difference(v.name for v in vars_to_recover)
342+
missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
332343
if missing_names:
333344
raise ValueError(f"Unrecognized var_names: {missing_names}")
334345

346+
if return_samples and random_seed is not None:
347+
seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
348+
else:
349+
seeds = [None] * len(vars_to_recover)
350+
335351
posterior = idata.posterior
336352

337353
# Remove Deterministics
@@ -351,9 +367,8 @@ def transform_input(inputs):
351367
posterior_pts = [transform_input(vs) for vs in posterior_pts]
352368

353369
rv_dict = {}
354-
rv_dims_dict = {}
355370

356-
for rv in vars_to_recover:
371+
for seed, rv in zip(seeds, vars_to_recover):
357372
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
358373
if not isinstance(rv.owner.op, supported_dists):
359374
raise NotImplementedError(
@@ -365,7 +380,21 @@ def transform_input(inputs):
365380
rv = m.vars_to_clone[rv]
366381
m.unmarginalize([rv])
367382
dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
368-
joint_logp = m.logp(vars=dependent_vars + [rv])
383+
joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
384+
385+
marginalized_value = m.rvs_to_values[rv]
386+
other_values = [v for v in m.value_vars if v is not marginalized_value]
387+
388+
# Handle batch dims for marginalized value and its dependent RVs
389+
joint_logp = joint_logps[-1]
390+
for dv in joint_logps[:-1]:
391+
dbcast = dv.type.broadcastable
392+
mbcast = marginalized_value.type.broadcastable
393+
mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
394+
values_axis_bcast = [
395+
i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
396+
]
397+
joint_logp += dv.sum(values_axis_bcast)
369398

370399
rv_shape = constant_fold(tuple(rv.shape))
371400
rv_domain = get_domain_of_finite_discrete_rv(rv)
@@ -379,27 +408,28 @@ def transform_input(inputs):
379408
0,
380409
)
381410

382-
marginalized_value = m.rvs_to_values[rv]
383-
other_values = [v for v in m.value_vars if v is not marginalized_value]
384-
385411
joint_logps = vectorize_graph(
386412
joint_logp,
387413
replace={marginalized_value: rv_domain_tensor},
388414
)
415+
joint_logps = pt.moveaxis(joint_logps, 0, -1)
389416

390417
rv_loglike_fn = None
418+
joint_logps_norm = log_softmax(joint_logps, axis=0)
391419
if return_samples:
392420
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
393421
rv_loglike_fn = compile_pymc(
394422
inputs=other_values,
395-
outputs=[log_softmax(joint_logps, axis=0), sample_rv_outs],
423+
outputs=[joint_logps_norm, sample_rv_outs],
396424
on_unused_input="ignore",
425+
random_seed=seed,
397426
)
398427
else:
399428
rv_loglike_fn = compile_pymc(
400429
inputs=other_values,
401-
outputs=log_softmax(joint_logps, axis=0),
430+
outputs=joint_logps_norm,
402431
on_unused_input="ignore",
432+
random_seed=seed,
403433
)
404434

405435
logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
@@ -409,18 +439,16 @@ def transform_input(inputs):
409439
if return_samples:
410440
logps, samples = zip(*logvs)
411441
logps = np.array(logps)
412-
rv_dict[rv.name] = np.reshape(
413-
samples, tuple(len(coord) for coord in stacked_dims.values())
442+
samples = np.array(samples)
443+
rv_dict[rv.name] = samples.reshape(
444+
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
414445
)
415-
rv_dims_dict[rv.name] = sample_dims
416446
else:
417447
logps = np.array(logvs)
418448

419-
rv_dict["lp_" + rv.name] = np.reshape(
420-
logps,
449+
rv_dict["lp_" + rv.name] = logps.reshape(
421450
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
422451
)
423-
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
424452

425453
coords, dims = coords_and_dims_for_inferencedata(self)
426454
rv_dataset = dict_to_dataset(
@@ -433,8 +461,7 @@ def transform_input(inputs):
433461
)
434462

435463
if extend_inferencedata:
436-
rv_dict = {k: (rv_dims_dict[k], v) for (k, v) in rv_dict.items()}
437-
idata = idata.posterior.assign(**rv_dict)
464+
idata.posterior = idata.posterior.assign(rv_dataset)
438465
return idata
439466
else:
440467
return rv_dataset

pymc_experimental/tests/model/test_marginal_model.py

+27-20
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,11 @@ def test_recover_marginals_basic():
290290
idata = InferenceData(posterior=dict_to_dataset(prior))
291291

292292
idata = m.recover_marginals(idata, return_samples=True)
293-
assert "k" in idata
294-
assert "lp_k" in idata
295-
assert idata.k.shape == idata.y.shape
296-
assert idata.lp_k.shape == idata.k.shape + (len(p),)
293+
post = idata.posterior
294+
assert "k" in post
295+
assert "lp_k" in post
296+
assert post.k.shape == post.y.shape
297+
assert post.lp_k.shape == post.k.shape + (len(p),)
297298

298299
def true_logp(y, sigma):
299300
y = y.repeat(len(p)).reshape(len(y), -1)
@@ -307,17 +308,17 @@ def true_logp(y, sigma):
307308
)
308309

309310
np.testing.assert_almost_equal(
310-
true_logp(idata.y.values.flatten(), idata.sigma.values.flatten()),
311-
idata.lp_k[0].values,
311+
true_logp(post.y.values.flatten(), post.sigma.values.flatten()),
312+
post.lp_k[0].values,
312313
)
313314

314315

315316
def test_recover_batched_marginal():
316317
"""Test that marginalization works for batched random variables"""
317318
with MarginalModel() as m:
318319
sigma = pm.HalfNormal("sigma")
319-
idx = pm.Bernoulli("idx", p=0.7, shape=(2, 3))
320-
y = pm.Normal("y", mu=idx, sigma=sigma, shape=(2, 3))
320+
idx = pm.Bernoulli("idx", p=0.7, shape=(3, 2))
321+
y = pm.Normal("y", mu=idx, sigma=sigma, shape=(3, 2))
321322

322323
m.marginalize([idx])
323324

@@ -334,6 +335,11 @@ def test_recover_batched_marginal():
334335
)
335336

336337
idata = m.recover_marginals(idata, return_samples=True)
338+
post = idata.posterior
339+
assert "idx" in post
340+
assert "lp_idx" in post
341+
assert post.idx.shape == post.y.shape
342+
assert post.lp_idx.shape == post.idx.shape + (2,)
337343

338344

339345
def test_nested_recover_marginals():
@@ -357,23 +363,24 @@ def test_nested_recover_marginals():
357363
idata = InferenceData(posterior=dict_to_dataset(prior))
358364

359365
idata = m.recover_marginals(idata, return_samples=True)
360-
assert "idx" in idata
361-
assert "lp_idx" in idata
362-
assert idata.idx.shape == idata.y.shape
363-
assert idata.lp_idx.shape == idata.idx.shape + (2,)
364-
assert "sub_idx" in idata
365-
assert "lp_sub_idx" in idata
366-
assert idata.sub_idx.shape == idata.y.shape
367-
assert idata.lp_sub_idx.shape == idata.sub_idx.shape + (2,)
366+
post = idata.posterior
367+
assert "idx" in post
368+
assert "lp_idx" in post
369+
assert post.idx.shape == post.y.shape
370+
assert post.lp_idx.shape == post.idx.shape + (2,)
371+
assert "sub_idx" in post
372+
assert "lp_sub_idx" in post
373+
assert post.sub_idx.shape == post.y.shape
374+
assert post.lp_sub_idx.shape == post.sub_idx.shape + (2,)
368375

369376
def true_idx_logp(y):
370377
idx_0 = np.log(0.85 * 0.25 * norm.pdf(y, loc=0) + 0.15 * 0.25 * norm.pdf(y, loc=1))
371378
idx_1 = np.log(0.05 * 0.75 * norm.pdf(y, loc=1) + 0.95 * 0.75 * norm.pdf(y, loc=2))
372379
return log_softmax(np.stack([idx_0, idx_1]).T, axis=1)
373380

374381
np.testing.assert_almost_equal(
375-
true_idx_logp(idata.y.values.flatten()),
376-
idata.lp_idx[0].values,
382+
true_idx_logp(post.y.values.flatten()),
383+
post.lp_idx[0].values,
377384
)
378385

379386
def true_sub_idx_logp(y):
@@ -382,8 +389,8 @@ def true_sub_idx_logp(y):
382389
return log_softmax(np.stack([sub_idx_0, sub_idx_1]).T, axis=1)
383390

384391
np.testing.assert_almost_equal(
385-
true_sub_idx_logp(idata.y.values.flatten()),
386-
idata.lp_sub_idx[0].values,
392+
true_sub_idx_logp(post.y.values.flatten()),
393+
post.lp_sub_idx[0].values,
387394
)
388395

389396

0 commit comments

Comments
 (0)