@@ -279,7 +279,7 @@ def unmarginalize(self, rvs_to_unmarginalize):
279
279
self .register_rv (rv , name = rv .name )
280
280
281
281
def recover_marginals (
282
- self , idata , var_names = None , include_samples = False , extend_inferencedata = True
282
+ self , idata , var_names = None , return_samples = False , extend_inferencedata = True
283
283
):
284
284
"""Computes unnormalized posterior probabilities of marginalized variables
285
285
conditioned on parameters of the model given InferenceData with posterior group
@@ -295,8 +295,8 @@ def recover_marginals(
295
295
InferenceData with posterior group
296
296
var_names : sequence of str, optional
297
297
List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
298
- include_samples : bool, default False
299
- Include samples of the marginalized variables
298
+ return_samples : bool, default False
299
+ If True, also return samples of the marginalized variables
300
300
extend_inferencedata : bool, default True
301
301
Whether to extend the original InferenceData or return a new one
302
302
@@ -360,7 +360,7 @@ def transform_input(inputs):
360
360
)
361
361
362
362
rv_loglike_fn = None
363
- if include_samples :
363
+ if return_samples :
364
364
sample_rv_outs = pymc .Categorical .dist (logit_p = joint_logps )
365
365
rv_loglike_fn = compile_pymc (
366
366
inputs = other_values ,
@@ -376,7 +376,7 @@ def transform_input(inputs):
376
376
377
377
logvs = [rv_loglike_fn (** vs ) for vs in posterior_pts ]
378
378
379
- if include_samples :
379
+ if return_samples :
380
380
logps , samples = zip (* logvs )
381
381
logps = np .array (logps )
382
382
rv_dict [rv .name ] = np .reshape (
0 commit comments