Skip to content

Commit ca0c28b

Browse files
committed
🔥 remove vars from sample_posterior_predictive, other random refactorings
1 parent dbcc49e commit ca0c28b

File tree

3 files changed

+12
-32
lines changed

3 files changed

+12
-32
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ This is the first release to support Python3.9 and to drop Python3.6.
88
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
99
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
1010
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).
11-
11+
- In `sample_posterior_predictive` the `vars` kwarg was removed in favor of `var_names` (see [#4343](https://github.com/pymc-devs/pymc3/pull/4343)).
1212

1313
## PyMC3 3.10.0 (7 December 2020)
1414

pymc3/sampling.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
from arviz import InferenceData
3535
from fastprogress.fastprogress import progress_bar
36-
from theano.tensor import Tensor
3736

3837
import pymc3 as pm
3938

@@ -561,12 +560,11 @@ def sample(
561560
_log.debug("Pickling error:", exec_info=True)
562561
parallel = False
563562
except AttributeError as e:
564-
if str(e).startswith("AttributeError: Can't pickle"):
565-
_log.warning("Could not pickle model, sampling singlethreaded.")
566-
_log.debug("Pickling error:", exec_info=True)
567-
parallel = False
568-
else:
563+
if not str(e).startswith("AttributeError: Can't pickle"):
569564
raise
565+
_log.warning("Could not pickle model, sampling singlethreaded.")
566+
_log.debug("Pickling error:", exec_info=True)
567+
parallel = False
570568
if not parallel:
571569
if has_population_samplers:
572570
has_demcmc = np.any(
@@ -1602,7 +1600,6 @@ def sample_posterior_predictive(
16021600
trace,
16031601
samples: Optional[int] = None,
16041602
model: Optional[Model] = None,
1605-
vars: Optional[Iterable[Tensor]] = None,
16061603
var_names: Optional[List[str]] = None,
16071604
size: Optional[int] = None,
16081605
keep_size: Optional[bool] = False,
@@ -1696,14 +1693,9 @@ def sample_posterior_predictive(
16961693
model = modelcontext(model)
16971694

16981695
if var_names is not None:
1699-
if vars is not None:
1700-
raise IncorrectArgumentsError("Should not specify both vars and var_names arguments.")
1701-
else:
1702-
vars = [model[x] for x in var_names]
1703-
elif vars is not None: # var_names is None, and vars is not.
1704-
warnings.warn("vars argument is deprecated in favor of var_names.", DeprecationWarning)
1705-
if vars is None:
1706-
vars = model.observed_RVs
1696+
vars_ = [model[x] for x in var_names]
1697+
else:
1698+
vars_ = model.observed_RVs
17071699

17081700
if random_seed is not None:
17091701
np.random.seed(random_seed)
@@ -1729,8 +1721,8 @@ def sample_posterior_predictive(
17291721
else:
17301722
param = _trace[idx % len_trace]
17311723

1732-
values = draw_values(vars, point=param, size=size)
1733-
for k, v in zip(vars, values):
1724+
values = draw_values(vars_, point=param, size=size)
1725+
for k, v in zip(vars_, values):
17341726
ppc_trace_t.insert(k.name, v, idx)
17351727
except KeyboardInterrupt:
17361728
pass
@@ -1809,7 +1801,7 @@ def sample_posterior_predictive_w(
18091801
raise ValueError("The number of models and weights should be the same")
18101802

18111803
length_morv = len(models[0].observed_RVs)
1812-
if not all(len(i.observed_RVs) == length_morv for i in models):
1804+
if any(len(i.observed_RVs) != length_morv for i in models):
18131805
raise ValueError("The number of observed RVs should be the same for all models")
18141806

18151807
weights = np.asarray(weights)

pymc3/tests/test_sampling.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,7 @@ def test_normal_scalar(self):
406406
ppc0 = pm.sample_posterior_predictive([model.test_point], samples=10)
407407
ppc0 = pm.fast_sample_posterior_predictive([model.test_point], samples=10)
408408
# deprecated argument is not introduced to fast version [2019/08/20:rpg]
409-
with pytest.warns(DeprecationWarning):
410-
ppc = pm.sample_posterior_predictive(trace, vars=[a])
409+
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
411410
# test empty ppc
412411
ppc = pm.sample_posterior_predictive(trace, var_names=[])
413412
assert len(ppc) == 0
@@ -518,8 +517,6 @@ def test_exceptions(self, caplog):
518517
# Not for fast_sample_posterior_predictive
519518
with pytest.raises(IncorrectArgumentsError):
520519
ppc = pm.sample_posterior_predictive(trace, size=4, keep_size=True)
521-
with pytest.raises(IncorrectArgumentsError):
522-
ppc = pm.sample_posterior_predictive(trace, vars=[a], var_names=["a"])
523520
# test wrong type argument
524521
bad_trace = {"mu": stats.norm.rvs(size=1000)}
525522
with pytest.raises(TypeError):
@@ -653,16 +650,7 @@ def test_deterministic_of_observed(self):
653650

654651
trace = pm.sample(100, chains=nchains)
655652
np.random.seed(0)
656-
with pytest.warns(DeprecationWarning):
657-
ppc = pm.sample_posterior_predictive(
658-
model=model,
659-
trace=trace,
660-
samples=len(trace) * nchains,
661-
vars=(model.deterministics + model.basic_RVs),
662-
)
663-
664653
rtol = 1e-5 if theano.config.floatX == "float64" else 1e-4
665-
npt.assert_allclose(ppc["in_1"] + ppc["in_2"], ppc["out"], rtol=rtol)
666654

667655
np.random.seed(0)
668656
ppc = pm.sample_posterior_predictive(

0 commit comments

Comments
 (0)