We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 13211ab commit 5ead708Copy full SHA for 5ead708
pymc/sampling_jax.py
@@ -114,7 +114,6 @@ def sample_numpyro_nuts(
114
var_names = model.unobserved_value_vars
115
116
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
117
- inputs = [model.rvs_to_values[i] for i in model.free_RVs]
118
119
tic1 = pd.Timestamp.now()
120
print("Compiling...", file=sys.stdout)
@@ -164,7 +163,7 @@ def sample_numpyro_nuts(
164
163
print("Transforming variables...", file=sys.stdout)
165
mcmc_samples = {}
166
for v in vars_to_sample:
167
- fgraph = FunctionGraph(inputs, [v], clone=False)
+ fgraph = FunctionGraph(model.value_vars, [v], clone=False)
168
jax_fn = jax_funcify(fgraph)
169
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
170
mcmc_samples[v.name] = result
0 commit comments