Skip to content

Commit 3a2a765

Browse files
OriolAbrilaloctavodia
authored andcommitted
fix default prior variable names (#3591)
* fix default prior variable names * update docs and add test on pm.Data not being in prior * Add model.potentials to prior_vars
1 parent 64d2b88 commit 3a2a765

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

pymc3/sampling.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,7 +1292,7 @@ def sample_prior_predictive(samples=500,
12921292
samples. *DEPRECATED* - Use ``var_names`` argument instead.
12931293
var_names : Iterable[str]
12941294
A list of names of variables for which to compute the posterior predictive
1295-
samples. Defaults to ``model.named_vars``.
1295+
samples. Defaults to both observed and unobserved RVs.
12961296
random_seed : int
12971297
Seed for the random number generator.
12981298
@@ -1305,8 +1305,13 @@ def sample_prior_predictive(samples=500,
13051305
model = modelcontext(model)
13061306

13071307
if vars is None and var_names is None:
1308-
vars = set(model.named_vars.keys())
1309-
vars_ = model.named_vars
1308+
prior_pred_vars = model.observed_RVs
1309+
prior_vars = (
1310+
get_default_varnames(model.unobserved_RVs, include_transformed=True) +
1311+
model.potentials
1312+
)
1313+
vars_ = [var.name for var in prior_vars + prior_pred_vars]
1314+
vars = set(vars_)
13101315
elif vars is None:
13111316
vars = var_names
13121317
vars_ = vars

pymc3/tests/test_sampling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,12 +505,14 @@ def test_ignores_observed(self):
505505
observed = np.random.normal(10, 1, size=200)
506506
with pm.Model():
507507
# Use a prior that's way off to show we're ignoring the observed variables
508+
observed_data = pm.Data("observed_data", observed)
508509
mu = pm.Normal("mu", mu=-100, sigma=1)
509510
positive_mu = pm.Deterministic("positive_mu", np.abs(mu))
510511
z = -1 - positive_mu
511-
pm.Normal("x_obs", mu=z, sigma=1, observed=observed)
512+
pm.Normal("x_obs", mu=z, sigma=1, observed=observed_data)
512513
prior = pm.sample_prior_predictive()
513514

515+
assert "observed_data" not in prior
514516
assert (prior["mu"] < 90).all()
515517
assert (prior["positive_mu"] > 90).all()
516518
assert (prior["x_obs"] < 90).all()

0 commit comments

Comments
 (0)