Skip to content

Commit 2bdae64

Browse files
committed
Change include_samples to return_samples for more consistent API
1 parent a7f63e2 commit 2bdae64

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

pymc_experimental/model/marginal_model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def unmarginalize(self, rvs_to_unmarginalize):
279279
self.register_rv(rv, name=rv.name)
280280

281281
def recover_marginals(
282-
self, idata, var_names=None, include_samples=False, extend_inferencedata=True
282+
self, idata, var_names=None, return_samples=False, extend_inferencedata=True
283283
):
284284
"""Computes unnormalized posterior probabilities of marginalized variables
285285
conditioned on parameters of the model given InferenceData with posterior group
@@ -295,8 +295,8 @@ def recover_marginals(
295295
InferenceData with posterior group
296296
var_names : sequence of str, optional
297297
List of Observed variable names for which to compute log_likelihood. Defaults to all observed variables
298-
include_samples : bool, default False
299-
Include samples of the marginalized variables
298+
return_samples : bool, default False
299+
If True, also return samples of the marginalized variables
300300
extend_inferencedata : bool, default True
301301
Whether to extend the original InferenceData or return a new one
302302
@@ -360,7 +360,7 @@ def transform_input(inputs):
360360
)
361361

362362
rv_loglike_fn = None
363-
if include_samples:
363+
if return_samples:
364364
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
365365
rv_loglike_fn = compile_pymc(
366366
inputs=other_values,
@@ -376,7 +376,7 @@ def transform_input(inputs):
376376

377377
logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
378378

379-
if include_samples:
379+
if return_samples:
380380
logps, samples = zip(*logvs)
381381
logps = np.array(logps)
382382
rv_dict[rv.name] = np.reshape(

pymc_experimental/tests/model/test_marginal_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_recover_marginals_basic():
274274
)
275275
idata = InferenceData(posterior=dict_to_dataset(prior))
276276

277-
idata = m.recover_marginals(idata, include_samples=True)
277+
idata = m.recover_marginals(idata, return_samples=True)
278278
assert "k" in idata
279279
assert "lp_k" in idata
280280
assert idata.k.shape == idata.y.shape
@@ -313,7 +313,7 @@ def test_nested_recover_marginals():
313313
)
314314
idata = InferenceData(posterior=dict_to_dataset(prior))
315315

316-
idata = m.recover_marginals(idata, include_samples=True)
316+
idata = m.recover_marginals(idata, return_samples=True)
317317
assert "idx" in idata
318318
assert "lp_idx" in idata
319319
assert idata.idx.shape == idata.y.shape

0 commit comments

Comments
 (0)