@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 import re
-import sys
 
 from datetime import datetime
 from functools import partial
@@ -53,6 +53,8 @@
     get_default_varnames,
 )
 
+logger = logging.getLogger(__name__)
+
 xla_flags_env = os.getenv("XLA_FLAGS", "")
 xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
 os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
@@ -289,40 +291,46 @@ def _update_coords_and_dims(
         dims.update(idata_kwargs.pop("dims"))
 
 
-@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
 def _blackjax_inference_loop(
-    seed,
-    init_position,
-    logprob_fn,
-    draws,
-    tune,
-    target_accept,
-    algorithm=None,
+    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax
 
-    if algorithm is None:
+    algorithm_name = adaptation_kwargs.pop("algorithm", "nuts")
+    if algorithm_name == "nuts":
         algorithm = blackjax.nuts
+    elif algorithm_name == "hmc":
+        algorithm = blackjax.hmc
+    else:
+        raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")
 
     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
         logdensity_fn=logprob_fn,
         target_acceptance_rate=target_accept,
+        **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
     kernel = algorithm(logprob_fn, **tuned_params).step
 
-    def inference_loop(rng_key, initial_state):
-        def one_step(state, rng_key):
-            state, info = kernel(rng_key, state)
-            return state, (state, info)
+    def _one_step(state, xs):
+        _, rng_key = xs
+        state, info = kernel(rng_key, state)
+        return state, (state, info)
 
-        keys = jax.random.split(rng_key, draws)
-        _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
+    progress_bar = adaptation_kwargs.pop("progress_bar", False)
+    if progress_bar:
+        from blackjax.progress_bar import progress_bar_scan
+
+        logger.info("Sample with tuned parameters")
+        one_step = jax.jit(progress_bar_scan(draws)(_one_step))
+    else:
+        one_step = jax.jit(_one_step)
 
-        return states, infos
+    keys = jax.random.split(seed, draws)
+    _, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
 
-    return inference_loop(seed, last_state)
+    return states, infos
 
 
 def sample_blackjax_nuts(
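A minimal standalone sketch (not part of the diff) of the pattern the new _one_step relies on: blackjax's progress_bar_scan, imported in the hunk above, reads the draw index from the first element of the scanned xs, which is why the scan now runs over (jnp.arange(draws), keys). The dummy transition kernel below is purely illustrative.

import jax
import jax.numpy as jnp
from blackjax.progress_bar import progress_bar_scan

draws = 100

@progress_bar_scan(draws)  # expects xs[0] to be the iteration index it displays
def _dummy_step(state, xs):
    _, rng_key = xs  # same unpacking as _one_step in the diff
    state = state + jax.random.normal(rng_key)  # placeholder "kernel"
    return state, state

keys = jax.random.split(jax.random.PRNGKey(0), draws)
_, samples = jax.lax.scan(_dummy_step, jnp.zeros(()), (jnp.arange(draws), keys))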
@@ -334,11 +342,13 @@ def sample_blackjax_nuts(
     initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
     model: Optional[Model] = None,
     var_names: Optional[Sequence[str]] = None,
+    progress_bar: bool = False,
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
     postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
     idata_kwargs: Optional[Dict[str, Any]] = None,
+    adaptation_kwargs: Optional[Dict[str, Any]] = None,
     postprocessing_chunks=None,  # deprecated
 ) -> az.InferenceData:
     """
@@ -415,7 +425,7 @@ def sample_blackjax_nuts(
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
 
     tic1 = datetime.now()
-    print("Compiling...", file=sys.stdout)
+    logger.info("Compiling...")
 
     init_params = _get_batched_jittered_initial_points(
         model=model,
@@ -432,36 +442,49 @@ def sample_blackjax_nuts(
     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)
 
-    get_posterior_samples = partial(
-        _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
-        tune=tune,
-        draws=draws,
-        target_accept=target_accept,
-    )
-
-    tic2 = datetime.now()
-    print("Compilation time = ", tic2 - tic1, file=sys.stdout)
-
-    print("Sampling...", file=sys.stdout)
+    if adaptation_kwargs is None:
+        adaptation_kwargs = {}
 
     # Adapted from numpyro
     if chain_method == "parallel":
         map_fn = jax.pmap
+        if progress_bar:
+            import warnings
+
+            warnings.warn(
+                "BlackJax currently only display progress bar correctly under "
+                "`chain_method == 'vectorized'`. Setting `progressbar=False`."
+            )
+            progress_bar = False
     elif chain_method == "vectorized":
         map_fn = jax.vmap
     else:
         raise ValueError(
             "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
         )
 
+    adaptation_kwargs["progress_bar"] = progress_bar
+    get_posterior_samples = partial(
+        _blackjax_inference_loop,
+        logprob_fn=logprob_fn,
+        tune=tune,
+        draws=draws,
+        target_accept=target_accept,
+        **adaptation_kwargs,
+    )
+
+    tic2 = datetime.now()
+    logger.info(f"Compilation time = {tic2 - tic1}")
+
+    logger.info("Sampling...")
+
     states, stats = map_fn(get_posterior_samples)(keys, init_params)
     raw_mcmc_samples = states.position
-    potential_energy = states.logdensity
+    potential_energy = states.logdensity.block_until_ready()
 
     tic3 = datetime.now()
-    print("Sampling time = ", tic3 - tic2, file=sys.stdout)
+    logger.info(f"Sampling time = {tic3 - tic2}")
 
-    print("Transforming variables...", file=sys.stdout)
+    logger.info("Transforming variables...")
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
     result = _postprocess_samples(
         jax_fn,
@@ -472,7 +495,7 @@ def sample_blackjax_nuts(
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
     mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
     tic4 = datetime.now()
-    print("Transformation time = ", tic4 - tic3, file=sys.stdout)
+    logger.info(f"Transformation time = {tic4 - tic3}")
 
     if idata_kwargs is None:
         idata_kwargs = {}
@@ -481,15 +504,15 @@ def sample_blackjax_nuts(
 
     if idata_kwargs.pop("log_likelihood", False):
         tic5 = datetime.now()
-        print("Computing Log Likelihood...", file=sys.stdout)
+        logger.info(f"Computing Log Likelihood...")
         log_likelihood = _get_log_likelihood(
             model,
             raw_mcmc_samples,
             backend=postprocessing_backend,
             postprocessing_vectorize=postprocessing_vectorize,
         )
         tic6 = datetime.now()
-        print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
+        logger.info(f"Log Likelihood time = {tic6 - tic5}")
     else:
         log_likelihood = None
 
@@ -634,7 +657,7 @@ def sample_numpyro_nuts(
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
 
     tic1 = datetime.now()
-    print("Compiling...", file=sys.stdout)
+    logger.info("Compiling...")
 
     init_params = _get_batched_jittered_initial_points(
         model=model,
@@ -663,9 +686,9 @@ def sample_numpyro_nuts(
     )
 
     tic2 = datetime.now()
-    print("Compilation time = ", tic2 - tic1, file=sys.stdout)
+    logger.info(f"Compilation time = {tic2 - tic1}")
 
-    print("Sampling...", file=sys.stdout)
+    logger.info("Sampling...")
 
     map_seed = jax.random.PRNGKey(random_seed)
     if chains > 1:
@@ -687,9 +710,9 @@ def sample_numpyro_nuts(
     raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
 
     tic3 = datetime.now()
-    print("Sampling time = ", tic3 - tic2, file=sys.stdout)
+    logger.info(f"Sampling time = {tic3 - tic2}")
 
-    print("Transforming variables...", file=sys.stdout)
+    logger.info("Transforming variables...")
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
     result = _postprocess_samples(
         jax_fn,
@@ -700,7 +723,7 @@ def sample_numpyro_nuts(
     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
 
     tic4 = datetime.now()
-    print("Transformation time = ", tic4 - tic3, file=sys.stdout)
+    logger.info(f"Transformation time = {tic4 - tic3}")
 
     if idata_kwargs is None:
         idata_kwargs = {}
@@ -709,15 +732,17 @@ def sample_numpyro_nuts(
 
     if idata_kwargs.pop("log_likelihood", False):
         tic5 = datetime.now()
-        print("Computing Log Likelihood...", file=sys.stdout)
+        logger.info(f"Computing Log Likelihood...")
         log_likelihood = _get_log_likelihood(
             model,
             raw_mcmc_samples,
             backend=postprocessing_backend,
             postprocessing_vectorize=postprocessing_vectorize,
         )
         tic6 = datetime.now()
-        print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
+        logger.info(
+            f"Log Likelihood time = {tic6 - tic5}",
+        )
     else:
         log_likelihood = None
 
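A minimal usage sketch (not part of the diff) of how the new keywords would be passed through, assuming the function is importable from pymc.sampling.jax as in recent PyMC releases; the toy model and parameter values are illustrative only. progress_bar is only honored with chain_method="vectorized", and adaptation_kwargs are forwarded to blackjax.window_adaptation after the optional "algorithm" entry ("nuts" or "hmc") is popped off.

import pymc as pm
from pymc.sampling.jax import sample_blackjax_nuts

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])

    idata = sample_blackjax_nuts(
        draws=500,
        tune=500,
        chains=2,
        chain_method="vectorized",
        progress_bar=True,
        adaptation_kwargs={"algorithm": "nuts"},
    )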