14
14
from pymc .logprob .transforms import IntervalTransform
15
15
from pymc .model import Model
16
16
from pymc .pytensorf import compile_pymc , constant_fold , inputvars
17
- from pymc .util import dataset_to_point_list , treedict
17
+ from pymc .util import _get_seeds_per_chain , dataset_to_point_list , treedict
18
18
from pytensor import Mode
19
19
from pytensor .compile import SharedVariable
20
20
from pytensor .compile .builders import OpFromGraph
@@ -284,7 +284,12 @@ def unmarginalize(self, rvs_to_unmarginalize):
284
284
self .register_rv (rv , name = rv .name )
285
285
286
286
def recover_marginals (
287
- self , idata , var_names = None , return_samples = True , extend_inferencedata = True
287
+ self ,
288
+ idata ,
289
+ var_names = None ,
290
+ return_samples = True ,
291
+ extend_inferencedata = True ,
292
+ random_seed = None ,
288
293
):
289
294
"""Computes posterior log-probabilities and samples of marginalized variables
290
295
conditioned on parameters of the model given InferenceData with posterior group
@@ -304,6 +309,8 @@ def recover_marginals(
304
309
If True, also return samples of the marginalized variables
305
310
extend_inferencedata : bool, default True
306
311
Whether to extend the original InferenceData or return a new one
312
+ random_seed: int, array-like of int or SeedSequence, optional
313
+ Seed used to generating samples
307
314
308
315
Returns
309
316
-------
@@ -328,16 +335,19 @@ def recover_marginals(
328
335
329
336
"""
330
337
if var_names is None :
331
- var_names = {var .name for var in self .marginalized_rvs }
332
- else :
333
- var_names = {var_names }
338
+ var_names = [var .name for var in self .marginalized_rvs ]
334
339
335
- var_names = { var if isinstance (var , str ) else var .name for var in var_names }
340
+ var_names = [ var if isinstance (var , str ) else var .name for var in var_names ]
336
341
vars_to_recover = [v for v in self .marginalized_rvs if v .name in var_names ]
337
- missing_names = var_names . difference ( v .name for v in vars_to_recover )
342
+ missing_names = [ v .name for v in vars_to_recover if v not in self . marginalized_rvs ]
338
343
if missing_names :
339
344
raise ValueError (f"Unrecognized var_names: { missing_names } " )
340
345
346
+ if return_samples and random_seed is not None :
347
+ seeds = _get_seeds_per_chain (random_seed , len (vars_to_recover ))
348
+ else :
349
+ seeds = [None ] * len (vars_to_recover )
350
+
341
351
posterior = idata .posterior
342
352
343
353
# Remove Deterministics
@@ -357,9 +367,8 @@ def transform_input(inputs):
357
367
posterior_pts = [transform_input (vs ) for vs in posterior_pts ]
358
368
359
369
rv_dict = {}
360
- rv_dims_dict = {}
361
370
362
- for rv in vars_to_recover :
371
+ for seed , rv in zip ( seeds , vars_to_recover ) :
363
372
supported_dists = (Bernoulli , Categorical , DiscreteUniform )
364
373
if not isinstance (rv .owner .op , supported_dists ):
365
374
raise NotImplementedError (
@@ -406,18 +415,21 @@ def transform_input(inputs):
406
415
joint_logps = pt .moveaxis (joint_logps , 0 , - 1 )
407
416
408
417
rv_loglike_fn = None
418
+ joint_logps_norm = log_softmax (joint_logps , axis = 0 )
409
419
if return_samples :
410
420
sample_rv_outs = pymc .Categorical .dist (logit_p = joint_logps )
411
421
rv_loglike_fn = compile_pymc (
412
422
inputs = other_values ,
413
- outputs = [log_softmax ( joint_logps , axis = 0 ) , sample_rv_outs ],
423
+ outputs = [joint_logps_norm , sample_rv_outs ],
414
424
on_unused_input = "ignore" ,
425
+ random_seed = seed ,
415
426
)
416
427
else :
417
428
rv_loglike_fn = compile_pymc (
418
429
inputs = other_values ,
419
- outputs = log_softmax ( joint_logps , axis = 0 ) ,
430
+ outputs = joint_logps_norm ,
420
431
on_unused_input = "ignore" ,
432
+ random_seed = seed ,
421
433
)
422
434
423
435
logvs = [rv_loglike_fn (** vs ) for vs in posterior_pts ]
@@ -431,14 +443,12 @@ def transform_input(inputs):
431
443
rv_dict [rv .name ] = samples .reshape (
432
444
tuple (len (coord ) for coord in stacked_dims .values ()) + samples .shape [1 :],
433
445
)
434
- rv_dims_dict [rv .name ] = sample_dims
435
446
else :
436
447
logps = np .array (logvs )
437
448
438
449
rv_dict ["lp_" + rv .name ] = logps .reshape (
439
450
tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :],
440
451
)
441
- rv_dims_dict ["lp_" + rv .name ] = sample_dims + ("lp_" + rv .name + "_dims" ,)
442
452
443
453
coords , dims = coords_and_dims_for_inferencedata (self )
444
454
rv_dataset = dict_to_dataset (
0 commit comments