17
17
18
18
from datetime import datetime
19
19
from functools import partial
20
- from typing import Any , Callable , Dict , List , Optional , Sequence , Union
20
+ from typing import Any , Callable , Dict , List , Literal , Optional , Sequence , Union
21
21
22
22
import arviz as az
23
23
import jax
26
26
import pytensor .tensor as pt
27
27
28
28
from arviz .data .base import make_attrs
29
- from jax .experimental . maps import SerialLoop , xmap
29
+ from jax .lax import scan
30
30
from pytensor .compile import SharedVariable , Supervisor , mode
31
31
from pytensor .graph .basic import graph_inputs
32
32
from pytensor .graph .fg import FunctionGraph
@@ -175,25 +175,29 @@ def _sample_stats_to_xarray(posterior):
175
175
return data
176
176
177
177
178
def _device_put(input, device: Optional[str]):
    """Move ``input`` onto the first JAX device of the requested backend.

    Parameters
    ----------
    input
        Array or pytree of arrays to transfer; handed unchanged to
        ``jax.device_put``. (The name shadows the builtin but is kept so
        existing callers are unaffected.)
    device : str, optional
        Backend name forwarded to ``jax.devices`` (for example ``"cpu"`` or
        ``"gpu"``). ``None`` is forwarded as-is, in which case JAX uses its
        default backend.

    Returns
    -------
    Whatever ``jax.device_put`` returns for ``input``.
    """
    target = jax.devices(device)[0]
    return jax.device_put(input, target)


def _postprocess_samples(
    jax_fn: Callable,
    raw_mcmc_samples: List[TensorVariable],
    # ``Optional[...]`` instead of ``... | None``: the annotation is evaluated
    # when the function is defined, and ``Literal[...] | None`` is only
    # supported from Python 3.10 on unless annotations are deferred.
    postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> List[TensorVariable]:
    """Apply ``jax_fn`` to every draw, vectorizing over the two leading axes.

    Parameters
    ----------
    jax_fn : Callable
        Function that takes one positional argument per entry of
        ``raw_mcmc_samples`` and returns a sequence of outputs.
    raw_mcmc_samples : list
        Arrays whose first two axes are mapped over
        (presumably ``(chain, draw, ...)`` -- confirm against the callers).
    postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None
        Backend the samples are placed on before ``jax_fn`` is evaluated.
    postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
        ``"vmap"`` maps ``jax_fn`` over both leading axes at once; ``"scan"``
        maps over one axis and loops sequentially over the other.

    Returns
    -------
    list
        One array per output of ``jax_fn``, with the two leading axes in the
        same order as in ``raw_mcmc_samples``.

    Raises
    ------
    ValueError
        If ``postprocessing_vectorize`` is neither ``"scan"`` nor ``"vmap"``.
    """
    if postprocessing_vectorize == "scan":
        # ``scan`` iterates over the leading axis, so bring axis 1 to the
        # front; axis 0 is then handled by the single ``vmap`` below.
        t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
        jax_vfn = jax.vmap(jax_fn)
        # No state is carried between steps: the carry is an empty tuple and
        # only the stacked per-step outputs are kept.
        _, outs = scan(
            lambda _, x: ((), jax_vfn(*x)),
            (),
            _device_put(t_raw_mcmc_samples, postprocessing_backend),
        )
        # Undo the swap so the outputs use the callers' axis order again.
        return [jnp.swapaxes(t, 0, 1) for t in outs]
    elif postprocessing_vectorize == "vmap":
        return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
    else:
        raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
197
201
198
202
199
203
def _blackjax_stats_to_dict (sample_stats , potential_energy ) -> Dict :
@@ -231,12 +235,17 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
231
235
232
236
233
237
def _get_log_likelihood(
    model: Model,
    samples,
    # ``Optional[...]`` instead of ``... | None``: the annotation is evaluated
    # when the function is defined, and ``Literal[...] | None`` is only
    # supported from Python 3.10 on unless annotations are deferred.
    backend: Optional[Literal["cpu", "gpu"]] = None,
    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
) -> Dict:
    """Compute log-likelihood for all observations.

    Parameters
    ----------
    model : Model
        Model whose ``observed_RVs`` and ``value_vars`` define the
        element-wise log-probability graph.
    samples
        Raw samples, one entry per ``model.value_vars``; forwarded to
        ``_postprocess_samples``.
    backend : Optional[Literal["cpu", "gpu"]], default None
        Forwarded to ``_postprocess_samples`` as the postprocessing backend.
    postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
        Forwarded to ``_postprocess_samples``.

    Returns
    -------
    dict
        Maps the name of each observed variable to its postprocessed result.
    """
    observed = model.observed_RVs
    # ``sum=False`` keeps one log-probability term per observed variable.
    elemwise_logp = model.logp(observed, sum=False)
    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
    result = _postprocess_samples(
        jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
    )
    return {v.name: r for v, r in zip(observed, result)}
241
250
242
251
@@ -297,11 +306,11 @@ def _blackjax_inference_loop(
297
306
298
307
adapt = blackjax .window_adaptation (
299
308
algorithm = algorithm ,
300
- logprob_fn = logprob_fn ,
301
- num_steps = tune ,
309
+ logdensity_fn = logprob_fn ,
302
310
target_acceptance_rate = target_accept ,
303
311
)
304
- last_state , kernel , _ = adapt .run (seed , init_position )
312
+ (last_state , tuned_params ), _ = adapt .run (seed , init_position , num_steps = tune )
313
+ kernel = algorithm (logprob_fn , ** tuned_params ).step
305
314
306
315
def inference_loop (rng_key , initial_state ):
307
316
def one_step (state , rng_key ):
@@ -327,9 +336,10 @@ def sample_blackjax_nuts(
327
336
var_names : Optional [Sequence [str ]] = None ,
328
337
keep_untransformed : bool = False ,
329
338
chain_method : str = "parallel" ,
330
- postprocessing_backend : Optional [ str ] = None ,
331
- postprocessing_chunks : Optional [ int ] = None ,
339
+ postprocessing_backend : Literal [ "cpu" , "gpu" ] | None = None ,
340
+ postprocessing_vectorize : Literal [ "vmap" , "scan" ] = "scan" ,
332
341
idata_kwargs : Optional [Dict [str , Any ]] = None ,
342
+ postprocessing_chunks = None , # deprecated
333
343
) -> az .InferenceData :
334
344
"""
335
345
Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
@@ -366,12 +376,10 @@ def sample_blackjax_nuts(
366
376
chain_method : str, default "parallel"
367
377
Specify how samples should be drawn. The choices include "parallel", and
368
378
"vectorized".
369
- postprocessing_backend : str, optional
379
+ postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
370
380
Specify how postprocessing should be computed. gpu or cpu
371
- postprocessing_chunks: Optional[int], default None
372
- Specify the number of chunks the postprocessing should be computed in. More
373
- chunks reduces memory usage at the cost of losing some vectorization, None
374
- uses jax.vmap
381
+ postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
382
+ How to vectorize the postprocessing: vmap or sequential scan
375
383
idata_kwargs : dict, optional
376
384
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
377
385
value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -387,6 +395,14 @@ def sample_blackjax_nuts(
387
395
with their respective sample stats and pointwise log likelihood values (unless
388
396
skipped with ``idata_kwargs``).
389
397
"""
398
+ if postprocessing_chunks is not None :
399
+ import warnings
400
+
401
+ warnings .warn (
402
+ "postprocessing_chunks is deprecated due to being unstable, "
403
+ "using postprocessing_vectorize='scan' instead" ,
404
+ DeprecationWarning ,
405
+ )
390
406
import blackjax
391
407
392
408
model = modelcontext (model )
@@ -441,14 +457,17 @@ def sample_blackjax_nuts(
441
457
442
458
states , stats = map_fn (get_posterior_samples )(keys , init_params )
443
459
raw_mcmc_samples = states .position
444
- potential_energy = states .potential_energy
460
+ potential_energy = states .logdensity
445
461
tic3 = datetime .now ()
446
462
print ("Sampling time = " , tic3 - tic2 , file = sys .stdout )
447
463
448
464
print ("Transforming variables..." , file = sys .stdout )
449
465
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
450
466
result = _postprocess_samples (
451
- jax_fn , raw_mcmc_samples , postprocessing_backend , num_chunks = postprocessing_chunks
467
+ jax_fn ,
468
+ raw_mcmc_samples ,
469
+ postprocessing_backend = postprocessing_backend ,
470
+ postprocessing_vectorize = postprocessing_vectorize ,
452
471
)
453
472
mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
454
473
mcmc_stats = _blackjax_stats_to_dict (stats , potential_energy )
@@ -467,7 +486,7 @@ def sample_blackjax_nuts(
467
486
model ,
468
487
raw_mcmc_samples ,
469
488
backend = postprocessing_backend ,
470
- num_chunks = postprocessing_chunks ,
489
+ postprocessing_vectorize = postprocessing_vectorize ,
471
490
)
472
491
tic6 = datetime .now ()
473
492
print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
@@ -527,10 +546,11 @@ def sample_numpyro_nuts(
527
546
progressbar : bool = True ,
528
547
keep_untransformed : bool = False ,
529
548
chain_method : str = "parallel" ,
530
- postprocessing_backend : Optional [ str ] = None ,
531
- postprocessing_chunks : Optional [ int ] = None ,
549
+ postprocessing_backend : Literal [ "cpu" , "gpu" ] | None = None ,
550
+ postprocessing_vectorize : Literal [ "vmap" , "scan" ] = "scan" ,
532
551
idata_kwargs : Optional [Dict ] = None ,
533
552
nuts_kwargs : Optional [Dict ] = None ,
553
+ postprocessing_chunks = None ,
534
554
) -> az .InferenceData :
535
555
"""
536
556
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
@@ -571,12 +591,10 @@ def sample_numpyro_nuts(
571
591
chain_method : str, default "parallel"
572
592
Specify how samples should be drawn. The choices include "sequential",
573
593
"parallel", and "vectorized".
574
- postprocessing_backend : Optional[str]
594
+ postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
575
595
Specify how postprocessing should be computed. gpu or cpu
576
- postprocessing_chunks: Optional[int], default None
577
- Specify the number of chunks the postprocessing should be computed in. More
578
- chunks reduces memory usage at the cost of losing some vectorization, None
579
- uses jax.vmap
596
+ postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
597
+ How to vectorize the postprocessing: vmap or sequential scan
580
598
idata_kwargs : dict, optional
581
599
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
582
600
value for the ``log_likelihood`` key to indicate that the pointwise log
@@ -594,7 +612,14 @@ def sample_numpyro_nuts(
594
612
with their respective sample stats and pointwise log likelihood values (unless
595
613
skipped with ``idata_kwargs``).
596
614
"""
615
+ if postprocessing_chunks is not None :
616
+ import warnings
597
617
618
+ warnings .warn (
619
+ "postprocessing_chunks is deprecated due to being unstable, "
620
+ "using postprocessing_vectorize='scan' instead" ,
621
+ DeprecationWarning ,
622
+ )
598
623
import numpyro
599
624
600
625
from numpyro .infer import MCMC , NUTS
@@ -667,7 +692,10 @@ def sample_numpyro_nuts(
667
692
print ("Transforming variables..." , file = sys .stdout )
668
693
jax_fn = get_jaxified_graph (inputs = model .value_vars , outputs = vars_to_sample )
669
694
result = _postprocess_samples (
670
- jax_fn , raw_mcmc_samples , postprocessing_backend , num_chunks = postprocessing_chunks
695
+ jax_fn ,
696
+ raw_mcmc_samples ,
697
+ postprocessing_backend = postprocessing_backend ,
698
+ postprocessing_vectorize = postprocessing_vectorize ,
671
699
)
672
700
mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
673
701
@@ -686,7 +714,7 @@ def sample_numpyro_nuts(
686
714
model ,
687
715
raw_mcmc_samples ,
688
716
backend = postprocessing_backend ,
689
- num_chunks = postprocessing_chunks ,
717
+ postprocessing_vectorize = postprocessing_vectorize ,
690
718
)
691
719
tic6 = datetime .now ()
692
720
print ("Log Likelihood time = " , tic6 - tic5 , file = sys .stdout )
0 commit comments