Skip to content

Commit 5352798

Browse files
authored
Allow for passing of backend and gradient_backend to nutpie (#7535)
* Allow for passing of backend and gradient_backend to nutpie * Extract nutpie compiler args explicitly
1 parent c61e9cd commit 5352798

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

pymc/sampling/mcmc.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,14 @@ def _sample_external_nuts(
305305
"`var_names` are currently ignored by the nutpie sampler",
306306
UserWarning,
307307
)
308-
compiled_model = nutpie.compile_pymc_model(model)
308+
compile_kwargs = {}
309+
for kwarg in ("backend", "gradient_backend"):
310+
if kwarg in nuts_sampler_kwargs:
311+
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
312+
compiled_model = nutpie.compile_pymc_model(
313+
model,
314+
**compile_kwargs,
315+
)
309316
t_start = time.time()
310317
idata = nutpie.sample(
311318
compiled_model,

0 commit comments

Comments
 (0)