Skip to content

Commit 6e5f258

Browse files
andrewdipperricardoV94
andrewdipper
authored andcommitted
add sample_jax_nuts docstring
1 parent 3888d53 commit 6e5f258

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

pymc/sampling/jax.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,72 @@ def sample_jax_nuts(
521521
compute_convergence_checks: bool = True,
522522
nuts_sampler: Literal["numpyro", "blackjax"],
523523
) -> az.InferenceData:
524+
"""
525+
Draw samples from the posterior using a jax NUTS method.
526+
527+
Parameters
528+
----------
529+
draws : int, default 1000
530+
The number of samples to draw. The number of tuned samples are discarded by
531+
default.
532+
tune : int, default 1000
533+
Number of iterations to tune. Samplers adjust the step sizes, scalings or
534+
similar during tuning. Tuning samples will be drawn in addition to the number
535+
specified in the ``draws`` argument.
536+
chains : int, default 4
537+
The number of chains to sample.
538+
target_accept : float in [0, 1].
539+
The step size is tuned such that we approximate this acceptance rate. Higher
540+
values like 0.9 or 0.95 often work better for problematic posteriors.
541+
random_seed : int, RandomState or Generator, optional
542+
Random seed used by the sampling steps.
543+
initvals: StartDict or Sequence[Optional[StartDict]], optional
544+
Initial values for random variables provided as a dictionary (or sequence of
545+
dictionaries) mapping the random variable (by name or reference) to desired
546+
starting values.
547+
jitter: bool, default True
548+
If True, add jitter to initial points.
549+
model : Model, optional
550+
Model to sample from. The model needs to have free random variables. When inside
551+
a ``with`` model context, it defaults to that model, otherwise the model must be
552+
passed explicitly.
553+
var_names : sequence of str, optional
554+
Names of variables for which to compute the posterior samples. Defaults to all
555+
variables in the posterior.
556+
nuts_kwargs : dict, optional
557+
Keyword arguments for underlying nuts sampler
558+
progressbar: bool, default True
559+
If True, display progressbar while sampling
560+
keep_untransformed : bool, default False
561+
Include untransformed variables in the posterior samples. Defaults to False.
562+
chain_method : str, default "parallel"
563+
Specify how samples should be drawn. The choices include "parallel", and
564+
"vectorized".
565+
postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
566+
Specify how postprocessing should be computed. gpu or cpu
567+
postprocessing_vectorize : Literal["vmap", "scan"], default "scan"
568+
How to vectorize the postprocessing: vmap or sequential scan
569+
postprocessing_chunks : None
570+
This argument is deprecated
571+
idata_kwargs : dict, optional
572+
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
573+
value for the ``log_likelihood`` key to indicate that the pointwise log
574+
likelihood should not be included in the returned object. Values for
575+
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
576+
the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
577+
``dims`` are provided, they are used to update the inferred dictionaries.
578+
compute_convergence_checks: bool, default True
579+
Compute ess and rhat values and warn if they indicate potential sampling issues.
580+
nuts_sampler : Literal["numpyro", "blackjax"]
581+
Nuts sampler library to use
582+
583+
Returns
584+
-------
585+
InferenceData
586+
ArviZ ``InferenceData`` object that contains the posterior samples, together
587+
with their respective sample stats and pointwise log likeihood values (unless
588+
skipped with ``idata_kwargs``).
589+
"""
524590
if postprocessing_chunks is not None:
525591
import warnings
526592

0 commit comments

Comments
 (0)