14
14
from pymc .logprob .transforms import IntervalTransform
15
15
from pymc .model import Model
16
16
from pymc .pytensorf import compile_pymc , constant_fold , inputvars
17
- from pymc .util import dataset_to_point_list , treedict
17
+ from pymc .util import _get_seeds_per_chain , dataset_to_point_list , treedict
18
18
from pytensor import Mode
19
19
from pytensor .compile import SharedVariable
20
20
from pytensor .compile .builders import OpFromGraph
@@ -233,9 +233,14 @@ def clone(self):
233
233
m .marginalized_rvs = [vars_to_clone [rv ] for rv in self .marginalized_rvs ]
234
234
return m
235
235
236
- def marginalize (self , rvs_to_marginalize : Union [TensorVariable , Sequence [TensorVariable ]]):
236
+ def marginalize (
237
+ self , rvs_to_marginalize : Union [TensorVariable , str , Sequence [TensorVariable ], Sequence [str ]]
238
+ ):
237
239
if not isinstance (rvs_to_marginalize , Sequence ):
238
240
rvs_to_marginalize = (rvs_to_marginalize ,)
241
+ rvs_to_marginalize = [
242
+ self [var ] if isinstance (var , str ) else var for var in rvs_to_marginalize
243
+ ]
239
244
240
245
supported_dists = (Bernoulli , Categorical , DiscreteUniform )
241
246
for rv_to_marginalize in rvs_to_marginalize :
@@ -279,7 +284,12 @@ def unmarginalize(self, rvs_to_unmarginalize):
279
284
self .register_rv (rv , name = rv .name )
280
285
281
286
def recover_marginals (
282
- self , idata , var_names = None , return_samples = True , extend_inferencedata = True
287
+ self ,
288
+ idata ,
289
+ var_names = None ,
290
+ return_samples = True ,
291
+ extend_inferencedata = True ,
292
+ random_seed = None ,
283
293
):
284
294
"""Computes posterior log-probabilities and samples of marginalized variables
285
295
conditioned on parameters of the model given InferenceData with posterior group
@@ -299,6 +309,8 @@ def recover_marginals(
299
309
If True, also return samples of the marginalized variables
300
310
extend_inferencedata : bool, default True
301
311
Whether to extend the original InferenceData or return a new one
312
+ random_seed: int, array-like of int or SeedSequence, optional
313
+ Seed used to generating samples
302
314
303
315
Returns
304
316
-------
@@ -323,15 +335,19 @@ def recover_marginals(
323
335
324
336
"""
325
337
if var_names is None :
326
- var_names = {var .name for var in self .marginalized_rvs }
327
- else :
328
- var_names = {var_names }
338
+ var_names = [var .name for var in self .marginalized_rvs ]
329
339
340
+ var_names = [var if isinstance (var , str ) else var .name for var in var_names ]
330
341
vars_to_recover = [v for v in self .marginalized_rvs if v .name in var_names ]
331
- missing_names = var_names . difference ( v .name for v in vars_to_recover )
342
+ missing_names = [ v .name for v in vars_to_recover if v not in self . marginalized_rvs ]
332
343
if missing_names :
333
344
raise ValueError (f"Unrecognized var_names: { missing_names } " )
334
345
346
+ if return_samples and random_seed is not None :
347
+ seeds = _get_seeds_per_chain (random_seed , len (vars_to_recover ))
348
+ else :
349
+ seeds = [None ] * len (vars_to_recover )
350
+
335
351
posterior = idata .posterior
336
352
337
353
# Remove Deterministics
@@ -351,9 +367,8 @@ def transform_input(inputs):
351
367
posterior_pts = [transform_input (vs ) for vs in posterior_pts ]
352
368
353
369
rv_dict = {}
354
- rv_dims_dict = {}
355
370
356
- for rv in vars_to_recover :
371
+ for seed , rv in zip ( seeds , vars_to_recover ) :
357
372
supported_dists = (Bernoulli , Categorical , DiscreteUniform )
358
373
if not isinstance (rv .owner .op , supported_dists ):
359
374
raise NotImplementedError (
@@ -365,7 +380,21 @@ def transform_input(inputs):
365
380
rv = m .vars_to_clone [rv ]
366
381
m .unmarginalize ([rv ])
367
382
dependent_vars = find_conditional_dependent_rvs (rv , m .basic_RVs )
368
- joint_logp = m .logp (vars = dependent_vars + [rv ])
383
+ joint_logps = m .logp (vars = dependent_vars + [rv ], sum = False )
384
+
385
+ marginalized_value = m .rvs_to_values [rv ]
386
+ other_values = [v for v in m .value_vars if v is not marginalized_value ]
387
+
388
+ # Handle batch dims for marginalized value and its dependent RVs
389
+ joint_logp = joint_logps [- 1 ]
390
+ for dv in joint_logps [:- 1 ]:
391
+ dbcast = dv .type .broadcastable
392
+ mbcast = marginalized_value .type .broadcastable
393
+ mbcast = (True ,) * (len (dbcast ) - len (mbcast )) + mbcast
394
+ values_axis_bcast = [
395
+ i for i , (m , v ) in enumerate (zip (mbcast , dbcast )) if m and not v
396
+ ]
397
+ joint_logp += dv .sum (values_axis_bcast )
369
398
370
399
rv_shape = constant_fold (tuple (rv .shape ))
371
400
rv_domain = get_domain_of_finite_discrete_rv (rv )
@@ -379,27 +408,28 @@ def transform_input(inputs):
379
408
0 ,
380
409
)
381
410
382
- marginalized_value = m .rvs_to_values [rv ]
383
- other_values = [v for v in m .value_vars if v is not marginalized_value ]
384
-
385
411
joint_logps = vectorize_graph (
386
412
joint_logp ,
387
413
replace = {marginalized_value : rv_domain_tensor },
388
414
)
415
+ joint_logps = pt .moveaxis (joint_logps , 0 , - 1 )
389
416
390
417
rv_loglike_fn = None
418
+ joint_logps_norm = log_softmax (joint_logps , axis = 0 )
391
419
if return_samples :
392
420
sample_rv_outs = pymc .Categorical .dist (logit_p = joint_logps )
393
421
rv_loglike_fn = compile_pymc (
394
422
inputs = other_values ,
395
- outputs = [log_softmax ( joint_logps , axis = 0 ) , sample_rv_outs ],
423
+ outputs = [joint_logps_norm , sample_rv_outs ],
396
424
on_unused_input = "ignore" ,
425
+ random_seed = seed ,
397
426
)
398
427
else :
399
428
rv_loglike_fn = compile_pymc (
400
429
inputs = other_values ,
401
- outputs = log_softmax ( joint_logps , axis = 0 ) ,
430
+ outputs = joint_logps_norm ,
402
431
on_unused_input = "ignore" ,
432
+ random_seed = seed ,
403
433
)
404
434
405
435
logvs = [rv_loglike_fn (** vs ) for vs in posterior_pts ]
@@ -409,18 +439,16 @@ def transform_input(inputs):
409
439
if return_samples :
410
440
logps , samples = zip (* logvs )
411
441
logps = np .array (logps )
412
- rv_dict [rv .name ] = np .reshape (
413
- samples , tuple (len (coord ) for coord in stacked_dims .values ())
442
+ samples = np .array (samples )
443
+ rv_dict [rv .name ] = samples .reshape (
444
+ tuple (len (coord ) for coord in stacked_dims .values ()) + samples .shape [1 :],
414
445
)
415
- rv_dims_dict [rv .name ] = sample_dims
416
446
else :
417
447
logps = np .array (logvs )
418
448
419
- rv_dict ["lp_" + rv .name ] = np .reshape (
420
- logps ,
449
+ rv_dict ["lp_" + rv .name ] = logps .reshape (
421
450
tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :],
422
451
)
423
- rv_dims_dict ["lp_" + rv .name ] = sample_dims + ("lp_" + rv .name + "_dims" ,)
424
452
425
453
coords , dims = coords_and_dims_for_inferencedata (self )
426
454
rv_dataset = dict_to_dataset (
@@ -433,8 +461,7 @@ def transform_input(inputs):
433
461
)
434
462
435
463
if extend_inferencedata :
436
- rv_dict = {k : (rv_dims_dict [k ], v ) for (k , v ) in rv_dict .items ()}
437
- idata = idata .posterior .assign (** rv_dict )
464
+ idata .posterior = idata .posterior .assign (rv_dataset )
438
465
return idata
439
466
else :
440
467
return rv_dataset
0 commit comments