@@ -132,13 +132,13 @@ def _sample_stats_to_xarray(posterior):
132
132
return data
133
133
134
134
135
- def _get_log_likelihood (model : Model , samples ) -> Dict :
135
+ def _get_log_likelihood (model : Model , samples , backend = None ) -> Dict :
136
136
"""Compute log-likelihood for all observations"""
137
137
data = {}
138
138
for v in model .observed_RVs :
139
139
v_elemwise_logpt = model .logpt (v , sum = False )
140
140
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = v_elemwise_logpt )
141
- result = jax .jit (jax .vmap (jax .vmap (jax_fn )))(* samples )[0 ]
141
+ result = jax .jit (jax .vmap (jax .vmap (jax_fn )), backend = backend )(* samples )[0 ]
142
142
data [v .name ] = result
143
143
return data
144
144
@@ -226,6 +226,7 @@ def sample_blackjax_nuts(
226
226
var_names = None ,
227
227
keep_untransformed = False ,
228
228
chain_method = "parallel" ,
229
+ postprocessing_backend = None ,
229
230
idata_kwargs = None ,
230
231
):
231
232
"""
@@ -255,6 +256,8 @@ def sample_blackjax_nuts(
255
256
Include untransformed variables in the posterior samples. Defaults to False.
256
257
chain_method : str, default "parallel"
257
258
Specify how samples should be drawn. The choices include "parallel", and "vectorized".
259
+ postprocessing_backend : str, optional
260
+ Specify how postprocessing should be computed. gpu or cpu
258
261
idata_kwargs : dict, optional
259
262
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
260
263
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
@@ -341,7 +344,9 @@ def sample_blackjax_nuts(
341
344
mcmc_samples = {}
342
345
for v in vars_to_sample :
343
346
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = [v ])
344
- result = jax .vmap (jax .vmap (jax_fn ))(* raw_mcmc_samples )[0 ]
347
+ result = jax .jit (jax .vmap (jax .vmap (jax_fn )), backend = postprocessing_backend )(
348
+ * raw_mcmc_samples
349
+ )[0 ]
345
350
mcmc_samples [v .name ] = result
346
351
347
352
tic4 = datetime .now ()
@@ -353,7 +358,9 @@ def sample_blackjax_nuts(
353
358
idata_kwargs = idata_kwargs .copy ()
354
359
355
360
if idata_kwargs .pop ("log_likelihood" , True ):
356
- log_likelihood = _get_log_likelihood (model , raw_mcmc_samples )
361
+ log_likelihood = _get_log_likelihood (
362
+ model , raw_mcmc_samples , backend = postprocessing_backend
363
+ )
357
364
else :
358
365
log_likelihood = None
359
366
@@ -387,6 +394,7 @@ def sample_numpyro_nuts(
387
394
progress_bar : bool = True ,
388
395
keep_untransformed : bool = False ,
389
396
chain_method : str = "parallel" ,
397
+ postprocessing_backend : str = None ,
390
398
idata_kwargs : Optional [Dict ] = None ,
391
399
nuts_kwargs : Optional [Dict ] = None ,
392
400
):
@@ -421,6 +429,8 @@ def sample_numpyro_nuts(
421
429
Include untransformed variables in the posterior samples. Defaults to False.
422
430
chain_method : str, default "parallel"
423
431
Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
432
+ postprocessing_backend : Optional[str]
433
+ Specify how postprocessing should be computed. gpu or cpu
424
434
idata_kwargs : dict, optional
425
435
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
426
436
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
@@ -525,7 +535,9 @@ def sample_numpyro_nuts(
525
535
mcmc_samples = {}
526
536
for v in vars_to_sample :
527
537
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = [v ])
528
- result = jax .vmap (jax .vmap (jax_fn ))(* raw_mcmc_samples )[0 ]
538
+ result = jax .jit (jax .vmap (jax .vmap (jax_fn )), backend = postprocessing_backend )(
539
+ * raw_mcmc_samples
540
+ )[0 ]
529
541
mcmc_samples [v .name ] = result
530
542
531
543
tic4 = datetime .now ()
@@ -537,7 +549,9 @@ def sample_numpyro_nuts(
537
549
idata_kwargs = idata_kwargs .copy ()
538
550
539
551
if idata_kwargs .pop ("log_likelihood" , True ):
540
- log_likelihood = _get_log_likelihood (model , raw_mcmc_samples )
552
+ log_likelihood = _get_log_likelihood (
553
+ model , raw_mcmc_samples , backend = postprocessing_backend
554
+ )
541
555
else :
542
556
log_likelihood = None
543
557
0 commit comments