@@ -391,69 +391,95 @@ def sample_blackjax_nuts(
391
391
return az_trace
392
392
393
393
394
+ def _numpyro_nuts_defaults () -> Dict [str , Any ]:
395
+ """Defaults parameters for Numpyro NUTS."""
396
+ return {
397
+ "adapt_step_size" : True ,
398
+ "adapt_mass_matrix" : True ,
399
+ "dense_mass" : False ,
400
+ }
401
+
402
+
403
+ def _update_numpyro_nuts_kwargs (nuts_kwargs : Optional [Dict [str , Any ]]) -> Dict [str , Any ]:
404
+ """Update default Numpyro NUTS parameters with new values."""
405
+ nuts_kwargs_defaults = _numpyro_nuts_defaults ()
406
+ if nuts_kwargs is not None :
407
+ nuts_kwargs_defaults .update (nuts_kwargs )
408
+ return nuts_kwargs_defaults
409
+
410
+
394
411
def sample_numpyro_nuts (
395
412
draws : int = 1000 ,
396
413
tune : int = 1000 ,
397
414
chains : int = 4 ,
398
415
target_accept : float = 0.8 ,
399
- random_seed : RandomSeed = None ,
416
+ random_seed : Optional [ RandomSeed ] = None ,
400
417
initvals : Optional [Union [StartDict , Sequence [Optional [StartDict ]]]] = None ,
401
418
model : Optional [Model ] = None ,
402
- var_names = None ,
419
+ var_names : Optional [ Sequence [ str ]] = None ,
403
420
progress_bar : bool = True ,
404
421
keep_untransformed : bool = False ,
405
422
chain_method : str = "parallel" ,
406
- postprocessing_backend : str = None ,
423
+ postprocessing_backend : Optional [ str ] = None ,
407
424
idata_kwargs : Optional [Dict ] = None ,
408
425
nuts_kwargs : Optional [Dict ] = None ,
409
- ):
426
+ ) -> az . InferenceData :
410
427
"""
411
428
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
412
429
413
430
Parameters
414
431
----------
415
432
draws : int, default 1000
416
- The number of samples to draw. The number of tuned samples are discarded by default.
433
+ The number of samples to draw. The number of tuned samples are discarded by
434
+ default.
417
435
tune : int, default 1000
418
436
Number of iterations to tune. Samplers adjust the step sizes, scalings or
419
- similar during tuning. Tuning samples will be drawn in addition to the number specified in
420
- the ``draws`` argument.
437
+ similar during tuning. Tuning samples will be drawn in addition to the number
438
+ specified in the ``draws`` argument.
421
439
chains : int, default 4
422
440
The number of chains to sample.
423
441
target_accept : float in [0, 1].
424
- The step size is tuned such that we approximate this acceptance rate. Higher values like
425
- 0.9 or 0.95 often work better for problematic posteriors.
442
+ The step size is tuned such that we approximate this acceptance rate. Higher
443
+ values like 0.9 or 0.95 often work better for problematic posteriors.
426
444
random_seed : int, RandomState or Generator, optional
427
445
Random seed used by the sampling steps.
446
+ initvals: StartDict or Sequence[Optional[StartDict]], optional
447
+ Initial values for random variables provided as a dictionary (or sequence of
448
+ dictionaries) mapping the random variable (by name or reference) to desired
449
+ starting values.
428
450
model : Model, optional
429
- Model to sample from. The model needs to have free random variables. When inside a ``with`` model
430
- context, it defaults to that model, otherwise the model must be passed explicitly.
431
- var_names : iterable of str, optional
432
- Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
451
+ Model to sample from. The model needs to have free random variables. When inside
452
+ a ``with`` model context, it defaults to that model, otherwise the model must be
453
+ passed explicitly.
454
+ var_names : sequence of str, optional
455
+ Names of variables for which to compute the posterior samples. Defaults to all
456
+ variables in the posterior.
433
457
progress_bar : bool, default True
434
- Whether or not to display a progress bar in the command line. The bar shows the percentage
435
- of completion, the sampling speed in samples per second (SPS), and the estimated remaining
436
- time until completion ("expected time of arrival"; ETA).
458
+ Whether or not to display a progress bar in the command line. The bar shows the
459
+ percentage of completion, the sampling speed in samples per second (SPS), and
460
+ the estimated remaining time until completion ("expected time of arrival"; ETA).
437
461
keep_untransformed : bool, default False
438
462
Include untransformed variables in the posterior samples. Defaults to False.
439
463
chain_method : str, default "parallel"
440
- Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
464
+ Specify how samples should be drawn. The choices include "sequential",
465
+ "parallel", and "vectorized".
441
466
postprocessing_backend : Optional[str]
442
467
Specify how postprocessing should be computed. gpu or cpu
443
468
idata_kwargs : dict, optional
444
- Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
445
- for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
446
- not be included in the returned object. Values for ``observed_data``, ``constant_data``,
447
- ``coords ``, and ``dims`` are inferred from the ``model `` argument if not provided
448
- in ``idata_kwargs``.
469
+ Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
470
+ value for the ``log_likelihood`` key to indicate that the pointwise log
471
+ likelihood should not be included in the returned object. Values for
472
+ ``observed_data ``, ``constant_data``, ``coords``, and ``dims `` are inferred from
473
+ the ``model`` argument if not provided in ``idata_kwargs``.
449
474
nuts_kwargs: dict, optional
450
475
Keyword arguments for :func:`numpyro.infer.NUTS`.
451
476
452
477
Returns
453
478
-------
454
479
InferenceData
455
- ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
456
- pointwise log likeihood values (unless skipped with ``idata_kwargs``).
480
+ ArviZ ``InferenceData`` object that contains the posterior samples, together
481
+ with their respective sample stats and pointwise log likeihood values (unless
482
+ skipped with ``idata_kwargs``).
457
483
"""
458
484
459
485
import numpyro
@@ -495,14 +521,10 @@ def sample_numpyro_nuts(
495
521
496
522
logp_fn = get_jaxified_logp (model , negative_logp = False )
497
523
498
- if nuts_kwargs is None :
499
- nuts_kwargs = {}
524
+ nuts_kwargs = _update_numpyro_nuts_kwargs (nuts_kwargs )
500
525
nuts_kernel = NUTS (
501
526
potential_fn = logp_fn ,
502
527
target_accept_prob = target_accept ,
503
- adapt_step_size = True ,
504
- adapt_mass_matrix = True ,
505
- dense_mass = False ,
506
528
** nuts_kwargs ,
507
529
)
508
530
0 commit comments