@@ -521,6 +521,72 @@ def sample_jax_nuts(
521
521
compute_convergence_checks : bool = True ,
522
522
nuts_sampler : Literal ["numpyro" , "blackjax" ],
523
523
) -> az .InferenceData :
524
+ """
525
+ Draw samples from the posterior using a jax NUTS method.
526
+
527
+ Parameters
528
+ ----------
529
+ draws : int, default 1000
530
+ The number of samples to draw. The number of tuned samples are discarded by
531
+ default.
532
+ tune : int, default 1000
533
+ Number of iterations to tune. Samplers adjust the step sizes, scalings or
534
+ similar during tuning. Tuning samples will be drawn in addition to the number
535
+ specified in the ``draws`` argument.
536
+ chains : int, default 4
537
+ The number of chains to sample.
538
+ target_accept : float in [0, 1].
539
+ The step size is tuned such that we approximate this acceptance rate. Higher
540
+ values like 0.9 or 0.95 often work better for problematic posteriors.
541
+ random_seed : int, RandomState or Generator, optional
542
+ Random seed used by the sampling steps.
543
+ initvals: StartDict or Sequence[Optional[StartDict]], optional
544
+ Initial values for random variables provided as a dictionary (or sequence of
545
+ dictionaries) mapping the random variable (by name or reference) to desired
546
+ starting values.
547
+ jitter: bool, default True
548
+ If True, add jitter to initial points.
549
+ model : Model, optional
550
+ Model to sample from. The model needs to have free random variables. When inside
551
+ a ``with`` model context, it defaults to that model, otherwise the model must be
552
+ passed explicitly.
553
+ var_names : sequence of str, optional
554
+ Names of variables for which to compute the posterior samples. Defaults to all
555
+ variables in the posterior.
556
+ nuts_kwargs : dict, optional
557
+ Keyword arguments for underlying nuts sampler
558
+ progressbar: bool, default True
559
+ If True, display progressbar while sampling
560
+ keep_untransformed : bool, default False
561
+ Include untransformed variables in the posterior samples. Defaults to False.
562
+ chain_method : str, default "parallel"
563
+ Specify how samples should be drawn. The choices include "parallel", and
564
+ "vectorized".
565
+ postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
566
+ Specify how postprocessing should be computed. gpu or cpu
567
+ postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
568
+ How to vectorize the postprocessing: vmap or sequential scan
569
+ postprocessing_chunks : None
570
+ This argument is deprecated
571
+ idata_kwargs : dict, optional
572
+ Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
573
+ value for the ``log_likelihood`` key to indicate that the pointwise log
574
+ likelihood should not be included in the returned object. Values for
575
+ ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
576
+ the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
577
+ ``dims`` are provided, they are used to update the inferred dictionaries.
578
+ compute_convergence_checks: bool, default True
579
+ Compute ess and rhat values and warn if they indicate potential sampling issues.
580
+ nuts_sampler : Literal["numpyro", "blackjax"]
581
+ Nuts sampler library to use
582
+
583
+ Returns
584
+ -------
585
+ InferenceData
586
+ ArviZ ``InferenceData`` object that contains the posterior samples, together
587
+ with their respective sample stats and pointwise log likeihood values (unless
588
+ skipped with ``idata_kwargs``).
589
+ """
524
590
if postprocessing_chunks is not None :
525
591
import warnings
526
592
0 commit comments