@@ -532,7 +532,7 @@ def sample_jax_nuts(
532
532
tune : int, default 1000
533
533
Number of iterations to tune. Samplers adjust the step sizes, scalings or
534
534
similar during tuning. Tuning samples will be drawn in addition to the number
535
- specified in the ``draws`` argument.
535
+ specified in the ``draws`` argument. Tuned samples are discarded.
536
536
chains : int, default 4
537
537
The number of chains to sample.
538
538
target_accept : float in [0, 1].
@@ -554,11 +554,11 @@ def sample_jax_nuts(
554
554
Names of variables for which to compute the posterior samples. Defaults to all
555
555
variables in the posterior.
556
556
nuts_kwargs : dict, optional
557
- Keyword arguments for underlying nuts sampler
558
- progressbar: bool, default True
559
- If True, display progressbar while sampling
557
+ Keyword arguments for the underlying nuts sampler
558
+ progressbar : bool, default True
559
+ If True, display a progressbar while sampling
560
560
keep_untransformed : bool, default False
561
- Include untransformed variables in the posterior samples. Defaults to False.
561
+ Include untransformed variables in the posterior samples.
562
562
chain_method : str, default "parallel"
563
563
Specify how samples should be drawn. The choices include "parallel", and
564
564
"vectorized".
@@ -575,10 +575,11 @@ def sample_jax_nuts(
575
575
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
576
576
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
577
577
``dims`` are provided, they are used to update the inferred dictionaries.
578
- compute_convergence_checks: bool, default True
579
- Compute ess and rhat values and warn if they indicate potential sampling issues.
578
+ compute_convergence_checks : bool, default True
579
+ If True, compute ess and rhat values and warn if they indicate potential sampling issues.
580
580
nuts_sampler : Literal["numpyro", "blackjax"]
581
- Nuts sampler library to use
581
+ Nuts sampler library to use - do not change - use sample_numpyro_nuts or
582
+ sample_blackjax_nuts as appropriate
582
583
583
584
Returns
584
585
-------
0 commit comments