Skip to content

Commit b45a72a

Browse files
committed
Add random_seed
1 parent 4ae0421 commit b45a72a

File tree

1 file changed

+23
-13
lines changed

1 file changed

+23
-13
lines changed

pymc_experimental/model/marginal_model.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pymc.logprob.transforms import IntervalTransform
1515
from pymc.model import Model
1616
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
17-
from pymc.util import dataset_to_point_list, treedict
17+
from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
1818
from pytensor import Mode
1919
from pytensor.compile import SharedVariable
2020
from pytensor.compile.builders import OpFromGraph
@@ -284,7 +284,12 @@ def unmarginalize(self, rvs_to_unmarginalize):
284284
self.register_rv(rv, name=rv.name)
285285

286286
def recover_marginals(
287-
self, idata, var_names=None, return_samples=True, extend_inferencedata=True
287+
self,
288+
idata,
289+
var_names=None,
290+
return_samples=True,
291+
extend_inferencedata=True,
292+
random_seed=None,
288293
):
289294
"""Computes posterior log-probabilities and samples of marginalized variables
290295
conditioned on parameters of the model given InferenceData with posterior group
@@ -304,6 +309,8 @@ def recover_marginals(
304309
If True, also return samples of the marginalized variables
305310
extend_inferencedata : bool, default True
306311
Whether to extend the original InferenceData or return a new one
312+
random_seed: int, array-like of int or SeedSequence, optional
313+
Seed used to generating samples
307314
308315
Returns
309316
-------
@@ -328,16 +335,19 @@ def recover_marginals(
328335
329336
"""
330337
if var_names is None:
331-
var_names = {var.name for var in self.marginalized_rvs}
332-
else:
333-
var_names = {var_names}
338+
var_names = [var.name for var in self.marginalized_rvs]
334339

335-
var_names = {var if isinstance(var, str) else var.name for var in var_names}
340+
var_names = [var if isinstance(var, str) else var.name for var in var_names]
336341
vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
337-
missing_names = var_names.difference(v.name for v in vars_to_recover)
342+
missing_names = [v.name for v in vars_to_recover if v not in self.marginalized_rvs]
338343
if missing_names:
339344
raise ValueError(f"Unrecognized var_names: {missing_names}")
340345

346+
if return_samples and random_seed is not None:
347+
seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
348+
else:
349+
seeds = [None] * len(vars_to_recover)
350+
341351
posterior = idata.posterior
342352

343353
# Remove Deterministics
@@ -357,9 +367,8 @@ def transform_input(inputs):
357367
posterior_pts = [transform_input(vs) for vs in posterior_pts]
358368

359369
rv_dict = {}
360-
rv_dims_dict = {}
361370

362-
for rv in vars_to_recover:
371+
for seed, rv in zip(seeds, vars_to_recover):
363372
supported_dists = (Bernoulli, Categorical, DiscreteUniform)
364373
if not isinstance(rv.owner.op, supported_dists):
365374
raise NotImplementedError(
@@ -406,18 +415,21 @@ def transform_input(inputs):
406415
joint_logps = pt.moveaxis(joint_logps, 0, -1)
407416

408417
rv_loglike_fn = None
418+
joint_logps_norm = log_softmax(joint_logps, axis=0)
409419
if return_samples:
410420
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
411421
rv_loglike_fn = compile_pymc(
412422
inputs=other_values,
413-
outputs=[log_softmax(joint_logps, axis=0), sample_rv_outs],
423+
outputs=[joint_logps_norm, sample_rv_outs],
414424
on_unused_input="ignore",
425+
random_seed=seed,
415426
)
416427
else:
417428
rv_loglike_fn = compile_pymc(
418429
inputs=other_values,
419-
outputs=log_softmax(joint_logps, axis=0),
430+
outputs=joint_logps_norm,
420431
on_unused_input="ignore",
432+
random_seed=seed,
421433
)
422434

423435
logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
@@ -431,14 +443,12 @@ def transform_input(inputs):
431443
rv_dict[rv.name] = samples.reshape(
432444
tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
433445
)
434-
rv_dims_dict[rv.name] = sample_dims
435446
else:
436447
logps = np.array(logvs)
437448

438449
rv_dict["lp_" + rv.name] = logps.reshape(
439450
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
440451
)
441-
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
442452

443453
coords, dims = coords_and_dims_for_inferencedata(self)
444454
rv_dataset = dict_to_dataset(

0 commit comments

Comments
 (0)