@@ -134,13 +134,10 @@ def _sample_stats_to_xarray(posterior):
134
134
135
135
def _get_log_likelihood (model : Model , samples , backend = None ) -> Dict :
136
136
"""Compute log-likelihood for all observations"""
137
- data = {}
138
- for v in model .observed_RVs :
139
- v_elemwise_logp = model .logp (v , sum = False )
140
- jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = v_elemwise_logp )
141
- result = jax .jit (jax .vmap (jax .vmap (jax_fn )), backend = backend )(* samples )[0 ]
142
- data [v .name ] = result
143
- return data
137
+ elemwise_logp = model .logp (model .observed_RVs , sum = False )
138
+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = elemwise_logp )
139
+ result = jax .vmap (jax .vmap (jax_fn ))(* jax .device_put (samples , jax .devices (backend )[0 ]))
140
+ return {v .name : r for v , r in zip (model .observed_RVs , result )}
144
141
145
142
146
143
def _get_batched_jittered_initial_points (
@@ -339,13 +336,11 @@ def sample_blackjax_nuts(
339
336
print ("Sampling time = " , tic3 - tic2 , file = sys .stdout )
340
337
341
338
print ("Transforming variables..." , file = sys .stdout )
342
- mcmc_samples = {}
343
- for v in vars_to_sample :
344
- jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = [v ])
345
- result = jax .jit (jax .vmap (jax .vmap (jax_fn )), backend = postprocessing_backend )(
346
- * raw_mcmc_samples
347
- )[0 ]
348
- mcmc_samples [v .name ] = result
339
+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
340
+ result = jax .vmap (jax .vmap (jax_fn ))(
341
+ * jax .device_put (raw_mcmc_samples , jax .devices (postprocessing_backend )[0 ])
342
+ )
343
+ mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
349
344
350
345
tic4 = datetime .now ()
351
346
print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
@@ -355,10 +350,14 @@ def sample_blackjax_nuts(
355
350
else :
356
351
idata_kwargs = idata_kwargs .copy ()
357
352
358
- if idata_kwargs .pop ("log_likelihood" , True ):
353
+ if idata_kwargs .pop ("log_likelihood" , bool (model .observed_RVs )):
354
+ tic5 = datetime .now ()
355
+ print ("Computing Log Likelihood..." , file = sys .stdout )
359
356
log_likelihood = _get_log_likelihood (
360
357
model , raw_mcmc_samples , backend = postprocessing_backend
361
358
)
359
+ tic6 = datetime .now ()
360
+ print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
362
361
else :
363
362
log_likelihood = None
364
363
@@ -531,13 +530,11 @@ def sample_numpyro_nuts(
531
530
print ("Sampling time = " , tic3 - tic2 , file = sys .stdout )
532
531
533
532
print ("Transforming variables..." , file = sys .stdout )
534
- mcmc_samples = {}
535
- for v in vars_to_sample :
536
- jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = [v ])
537
- result = jax .jit (jax .vmap (jax .vmap (jax_fn )), backend = postprocessing_backend )(
538
- * raw_mcmc_samples
539
- )[0 ]
540
- mcmc_samples [v .name ] = result
533
+ jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
534
+ result = jax .vmap (jax .vmap (jax_fn ))(
535
+ * jax .device_put (raw_mcmc_samples , jax .devices (postprocessing_backend )[0 ])
536
+ )
537
+ mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
541
538
542
539
tic4 = datetime .now ()
543
540
print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
@@ -547,10 +544,14 @@ def sample_numpyro_nuts(
547
544
else :
548
545
idata_kwargs = idata_kwargs .copy ()
549
546
550
- if idata_kwargs .pop ("log_likelihood" , True ):
547
+ if idata_kwargs .pop ("log_likelihood" , bool (model .observed_RVs )):
548
+ tic5 = datetime .now ()
549
+ print ("Computing Log Likelihood..." , file = sys .stdout )
551
550
log_likelihood = _get_log_likelihood (
552
551
model , raw_mcmc_samples , backend = postprocessing_backend
553
552
)
553
+ tic6 = datetime .now ()
554
+ print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
554
555
else :
555
556
log_likelihood = None
556
557
0 commit comments