Skip to content

Commit a8279d7

Browse files
authored
Pass user-provided NUTS kwargs to Numpyro (#6021)
* refactor: update user supplied NUTS kwargs with defualts for Numpyro sampler * style: fix typehint in `sample_numpyro_nuts` * doc: add missing `initvals` to docstring of `sample_numpyro_nuts` * style: return typehint for `sample_numpyro_nuts` * refactor: change `var_names` typehint from `Iterable` to `Sequence` * test: updating NUTS kwargs for Numpyro sampler function * test: use monkeypatch to determine if custum NUTS kwargs are used * test: replace monkey patch with simpler mocking method
1 parent f3ac08b commit a8279d7

File tree

2 files changed

+111
-29
lines changed

2 files changed

+111
-29
lines changed

pymc/sampling_jax.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -391,69 +391,95 @@ def sample_blackjax_nuts(
391391
return az_trace
392392

393393

394+
def _numpyro_nuts_defaults() -> Dict[str, Any]:
395+
"""Defaults parameters for Numpyro NUTS."""
396+
return {
397+
"adapt_step_size": True,
398+
"adapt_mass_matrix": True,
399+
"dense_mass": False,
400+
}
401+
402+
403+
def _update_numpyro_nuts_kwargs(nuts_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
404+
"""Update default Numpyro NUTS parameters with new values."""
405+
nuts_kwargs_defaults = _numpyro_nuts_defaults()
406+
if nuts_kwargs is not None:
407+
nuts_kwargs_defaults.update(nuts_kwargs)
408+
return nuts_kwargs_defaults
409+
410+
394411
def sample_numpyro_nuts(
395412
draws: int = 1000,
396413
tune: int = 1000,
397414
chains: int = 4,
398415
target_accept: float = 0.8,
399-
random_seed: RandomSeed = None,
416+
random_seed: Optional[RandomSeed] = None,
400417
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
401418
model: Optional[Model] = None,
402-
var_names=None,
419+
var_names: Optional[Sequence[str]] = None,
403420
progress_bar: bool = True,
404421
keep_untransformed: bool = False,
405422
chain_method: str = "parallel",
406-
postprocessing_backend: str = None,
423+
postprocessing_backend: Optional[str] = None,
407424
idata_kwargs: Optional[Dict] = None,
408425
nuts_kwargs: Optional[Dict] = None,
409-
):
426+
) -> az.InferenceData:
410427
"""
411428
Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
412429
413430
Parameters
414431
----------
415432
draws : int, default 1000
416-
The number of samples to draw. The number of tuned samples are discarded by default.
433+
The number of samples to draw. The number of tuned samples are discarded by
434+
default.
417435
tune : int, default 1000
418436
Number of iterations to tune. Samplers adjust the step sizes, scalings or
419-
similar during tuning. Tuning samples will be drawn in addition to the number specified in
420-
the ``draws`` argument.
437+
similar during tuning. Tuning samples will be drawn in addition to the number
438+
specified in the ``draws`` argument.
421439
chains : int, default 4
422440
The number of chains to sample.
423441
target_accept : float in [0, 1].
424-
The step size is tuned such that we approximate this acceptance rate. Higher values like
425-
0.9 or 0.95 often work better for problematic posteriors.
442+
The step size is tuned such that we approximate this acceptance rate. Higher
443+
values like 0.9 or 0.95 often work better for problematic posteriors.
426444
random_seed : int, RandomState or Generator, optional
427445
Random seed used by the sampling steps.
446+
initvals: StartDict or Sequence[Optional[StartDict]], optional
447+
Initial values for random variables provided as a dictionary (or sequence of
448+
dictionaries) mapping the random variable (by name or reference) to desired
449+
starting values.
428450
model : Model, optional
429-
Model to sample from. The model needs to have free random variables. When inside a ``with`` model
430-
context, it defaults to that model, otherwise the model must be passed explicitly.
431-
var_names : iterable of str, optional
432-
Names of variables for which to compute the posterior samples. Defaults to all variables in the posterior
451+
Model to sample from. The model needs to have free random variables. When inside
452+
a ``with`` model context, it defaults to that model, otherwise the model must be
453+
passed explicitly.
454+
var_names : sequence of str, optional
455+
Names of variables for which to compute the posterior samples. Defaults to all
456+
variables in the posterior.
433457
progress_bar : bool, default True
434-
Whether or not to display a progress bar in the command line. The bar shows the percentage
435-
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
436-
time until completion ("expected time of arrival"; ETA).
458+
Whether or not to display a progress bar in the command line. The bar shows the
459+
percentage of completion, the sampling speed in samples per second (SPS), and
460+
the estimated remaining time until completion ("expected time of arrival"; ETA).
437461
keep_untransformed : bool, default False
438462
Include untransformed variables in the posterior samples. Defaults to False.
439463
chain_method : str, default "parallel"
440-
Specify how samples should be drawn. The choices include "sequential", "parallel", and "vectorized".
464+
Specify how samples should be drawn. The choices include "sequential",
465+
"parallel", and "vectorized".
441466
postprocessing_backend : Optional[str]
442467
Specify how postprocessing should be computed. gpu or cpu
443468
idata_kwargs : dict, optional
444-
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
445-
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
446-
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
447-
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
448-
in ``idata_kwargs``.
469+
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
470+
value for the ``log_likelihood`` key to indicate that the pointwise log
471+
likelihood should not be included in the returned object. Values for
472+
``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
473+
the ``model`` argument if not provided in ``idata_kwargs``.
449474
nuts_kwargs: dict, optional
450475
Keyword arguments for :func:`numpyro.infer.NUTS`.
451476
452477
Returns
453478
-------
454479
InferenceData
455-
ArviZ ``InferenceData`` object that contains the posterior samples, together with their respective sample stats and
456-
pointwise log likeihood values (unless skipped with ``idata_kwargs``).
480+
ArviZ ``InferenceData`` object that contains the posterior samples, together
481+
with their respective sample stats and pointwise log likeihood values (unless
482+
skipped with ``idata_kwargs``).
457483
"""
458484

459485
import numpyro
@@ -495,14 +521,10 @@ def sample_numpyro_nuts(
495521

496522
logp_fn = get_jaxified_logp(model, negative_logp=False)
497523

498-
if nuts_kwargs is None:
499-
nuts_kwargs = {}
524+
nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
500525
nuts_kernel = NUTS(
501526
potential_fn=logp_fn,
502527
target_accept_prob=target_accept,
503-
adapt_step_size=True,
504-
adapt_mass_matrix=True,
505-
dense_mass=False,
506528
**nuts_kwargs,
507529
)
508530

pymc/tests/test_sampling_jax.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Any, Dict
2+
from unittest import mock
3+
14
import aesara
25
import aesara.tensor as at
36
import jax
@@ -6,13 +9,16 @@
69

710
from aesara.compile import SharedVariable
811
from aesara.graph import graph_inputs
12+
from numpyro.infer import MCMC
913

1014
import pymc as pm
1115

1216
from pymc.sampling_jax import (
1317
_get_batched_jittered_initial_points,
1418
_get_log_likelihood,
19+
_numpyro_nuts_defaults,
1520
_replace_shared_variables,
21+
_update_numpyro_nuts_kwargs,
1622
get_jaxified_graph,
1723
get_jaxified_logp,
1824
sample_blackjax_nuts,
@@ -270,3 +276,57 @@ def test_seeding(chains, random_seed, sampler):
270276
if chains > 1:
271277
assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
272278
assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
279+
280+
281+
@pytest.mark.parametrize(
282+
"nuts_kwargs",
283+
[
284+
{"adapt_step_size": False},
285+
{"adapt_mass_matrix": True},
286+
{"dense_mass": True},
287+
{"adapt_step_size": False, "adapt_mass_matrix": True, "dense_mass": True},
288+
{"fake-key": "fake-value"},
289+
],
290+
)
291+
def test_update_numpyro_nuts_kwargs(nuts_kwargs: Dict[str, Any]):
292+
original_kwargs = nuts_kwargs.copy()
293+
new_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
294+
295+
# Maintains original key-value pairs.
296+
for k, v in original_kwargs.items():
297+
assert new_kwargs[k] == v
298+
299+
for k, v in _numpyro_nuts_defaults().items():
300+
if k not in original_kwargs:
301+
assert new_kwargs[k] == v
302+
303+
304+
@mock.patch("numpyro.infer.MCMC")
305+
def test_numpyro_nuts_kwargs_are_used(mocked: mock.MagicMock):
306+
mocked.side_effect = MCMC
307+
308+
step_size = 0.13
309+
dense_mass = True
310+
adapt_step_size = False
311+
target_accept = 0.78
312+
313+
with pm.Model():
314+
pm.Normal("a")
315+
sample_numpyro_nuts(
316+
10,
317+
tune=10,
318+
chains=1,
319+
target_accept=target_accept,
320+
nuts_kwargs={
321+
"step_size": step_size,
322+
"dense_mass": dense_mass,
323+
"adapt_step_size": adapt_step_size,
324+
},
325+
)
326+
mocked.assert_called_once()
327+
nuts_sampler = mocked.call_args.args[0]
328+
assert nuts_sampler._step_size == step_size
329+
assert nuts_sampler._dense_mass == dense_mass
330+
assert nuts_sampler._adapt_step_size == adapt_step_size
331+
assert nuts_sampler._adapt_mass_matrix
332+
assert nuts_sampler._target_accept_prob == target_accept

0 commit comments

Comments
 (0)