Skip to content

Commit b59a9eb

Browse files
andrewdipperricardoV94
andrewdipper
authored andcommitted
update sample_jax_nuts docstring
1 parent 6e5f258 commit b59a9eb

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

pymc/sampling/jax.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def sample_jax_nuts(
532532
tune : int, default 1000
533533
Number of iterations to tune. Samplers adjust the step sizes, scalings or
534534
similar during tuning. Tuning samples will be drawn in addition to the number
535-
specified in the ``draws`` argument.
535+
specified in the ``draws`` argument. Tuned samples are discarded.
536536
chains : int, default 4
537537
The number of chains to sample.
538538
target_accept : float in [0, 1].
@@ -554,11 +554,11 @@ def sample_jax_nuts(
554554
Names of variables for which to compute the posterior samples. Defaults to all
555555
variables in the posterior.
556556
nuts_kwargs : dict, optional
557-
Keyword arguments for underlying nuts sampler
558-
progressbar: bool, default True
559-
If True, display progressbar while sampling
557+
Keyword arguments for the underlying nuts sampler
558+
progressbar : bool, default True
559+
If True, display a progressbar while sampling
560560
keep_untransformed : bool, default False
561-
Include untransformed variables in the posterior samples. Defaults to False.
561+
Include untransformed variables in the posterior samples.
562562
chain_method : str, default "parallel"
563563
Specify how samples should be drawn. The choices include "parallel", and
564564
"vectorized".
@@ -575,10 +575,11 @@ def sample_jax_nuts(
575575
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
576576
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
577577
``dims`` are provided, they are used to update the inferred dictionaries.
578-
compute_convergence_checks: bool, default True
579-
Compute ess and rhat values and warn if they indicate potential sampling issues.
578+
compute_convergence_checks : bool, default True
579+
If True, compute ess and rhat values and warn if they indicate potential sampling issues.
580580
nuts_sampler : Literal["numpyro", "blackjax"]
581-
Nuts sampler library to use
581+
Nuts sampler library to use - do not change - use sample_numpyro_nuts or
582+
sample_blackjax_nuts as appropriate
582583
583584
Returns
584585
-------

0 commit comments

Comments
 (0)