Skip to content

Commit dd4109c

Browse files
Do not use shared variables as inputs during prior/posterior sampling
1 parent 01189d3 commit dd4109c

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pymc3/sampling.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import packaging
3333
import xarray
3434

35+
from aesara.tensor.sharedvar import SharedVariable
3536
from arviz import InferenceData
3637
from fastprogress.fastprogress import progress_bar
3738

@@ -1730,15 +1731,21 @@ def sample_posterior_predictive(
17301731
inputs_and_names = [
17311732
(rv, rv.name)
17321733
for rv in rv_ancestors(vars_to_sample, walk_past_rvs=True)
1733-
if rv not in vars_to_sample and rv in model.named_vars.values()
1734+
if rv not in vars_to_sample
1735+
and rv in model.named_vars.values()
1736+
and not isinstance(rv, SharedVariable)
17341737
]
17351738
if inputs_and_names:
17361739
inputs, input_names = zip(*inputs_and_names)
17371740
else:
17381741
inputs, input_names = [], []
17391742
else:
17401743
output_names = [v.name for v in vars_to_sample if v.name is not None]
1741-
input_names = [n for n in _trace.varnames if n not in output_names]
1744+
input_names = [
1745+
n
1746+
for n in _trace.varnames
1747+
if n not in output_names and not isinstance(model[n], SharedVariable)
1748+
]
17421749
inputs = [model[n] for n in input_names]
17431750

17441751
if size is not None:
@@ -1987,7 +1994,7 @@ def sample_prior_predictive(
19871994
names = get_default_varnames(vars_, include_transformed=False)
19881995

19891996
vars_to_sample = [model[name] for name in names]
1990-
inputs = [i for i in inputvars(vars_to_sample)]
1997+
inputs = [i for i in inputvars(vars_to_sample) if not isinstance(i, SharedVariable)]
19911998
sampler_fn = aesara.function(
19921999
inputs,
19932000
vars_to_sample,

0 commit comments

Comments
 (0)