Skip to content

Commit 5ead708

Browse files
committed
Use value_vars directly
1 parent 13211ab commit 5ead708

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

pymc/sampling_jax.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def sample_numpyro_nuts(
114114
var_names = model.unobserved_value_vars
115115

116116
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
117-
inputs = [model.rvs_to_values[i] for i in model.free_RVs]
118117

119118
tic1 = pd.Timestamp.now()
120119
print("Compiling...", file=sys.stdout)
@@ -164,7 +163,7 @@ def sample_numpyro_nuts(
164163
print("Transforming variables...", file=sys.stdout)
165164
mcmc_samples = {}
166165
for v in vars_to_sample:
167-
fgraph = FunctionGraph(inputs, [v], clone=False)
166+
fgraph = FunctionGraph(model.value_vars, [v], clone=False)
168167
jax_fn = jax_funcify(fgraph)
169168
result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
170169
mcmc_samples[v.name] = result

0 commit comments

Comments
 (0)