Skip to content

Commit c3f93ba

Browse files
authored
Improve blackjax sampling integration (#6963)
* More fine-tuned control in blackjax sampling. * Formatting * Fix test * Enable progress bar for sampling * Use standard progressbar kwarg
1 parent 009da35 commit c3f93ba

File tree

2 files changed

+71
-45
lines changed

2 files changed

+71
-45
lines changed

pymc/sampling/jax.py

Lines changed: 70 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
1415
import os
1516
import re
16-
import sys
1717

1818
from datetime import datetime
1919
from functools import partial
@@ -53,6 +53,8 @@
5353
get_default_varnames,
5454
)
5555

56+
logger = logging.getLogger(__name__)
57+
5658
xla_flags_env = os.getenv("XLA_FLAGS", "")
5759
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
5860
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
@@ -289,40 +291,46 @@ def _update_coords_and_dims(
289291
dims.update(idata_kwargs.pop("dims"))
290292

291293

292-
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
293294
def _blackjax_inference_loop(
294-
seed,
295-
init_position,
296-
logprob_fn,
297-
draws,
298-
tune,
299-
target_accept,
300-
algorithm=None,
295+
seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
301296
):
302297
import blackjax
303298

304-
if algorithm is None:
299+
algorithm_name = adaptation_kwargs.pop("algorithm", "nuts")
300+
if algorithm_name == "nuts":
305301
algorithm = blackjax.nuts
302+
elif algorithm_name == "hmc":
303+
algorithm = blackjax.hmc
304+
else:
305+
raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")
306306

307307
adapt = blackjax.window_adaptation(
308308
algorithm=algorithm,
309309
logdensity_fn=logprob_fn,
310310
target_acceptance_rate=target_accept,
311+
**adaptation_kwargs,
311312
)
312313
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
313314
kernel = algorithm(logprob_fn, **tuned_params).step
314315

315-
def inference_loop(rng_key, initial_state):
316-
def one_step(state, rng_key):
317-
state, info = kernel(rng_key, state)
318-
return state, (state, info)
316+
def _one_step(state, xs):
317+
_, rng_key = xs
318+
state, info = kernel(rng_key, state)
319+
return state, (state, info)
319320

320-
keys = jax.random.split(rng_key, draws)
321-
_, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
321+
progress_bar = adaptation_kwargs.pop("progress_bar", False)
322+
if progress_bar:
323+
from blackjax.progress_bar import progress_bar_scan
324+
325+
logger.info("Sample with tuned parameters")
326+
one_step = jax.jit(progress_bar_scan(draws)(_one_step))
327+
else:
328+
one_step = jax.jit(_one_step)
322329

323-
return states, infos
330+
keys = jax.random.split(seed, draws)
331+
_, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
324332

325-
return inference_loop(seed, last_state)
333+
return states, infos
326334

327335

328336
def sample_blackjax_nuts(
@@ -334,11 +342,13 @@ def sample_blackjax_nuts(
334342
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
335343
model: Optional[Model] = None,
336344
var_names: Optional[Sequence[str]] = None,
345+
progress_bar: bool = False,
337346
keep_untransformed: bool = False,
338347
chain_method: str = "parallel",
339348
postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
340349
postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
341350
idata_kwargs: Optional[Dict[str, Any]] = None,
351+
adaptation_kwargs: Optional[Dict[str, Any]] = None,
342352
postprocessing_chunks=None, # deprecated
343353
) -> az.InferenceData:
344354
"""
@@ -415,7 +425,7 @@ def sample_blackjax_nuts(
415425
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
416426

417427
tic1 = datetime.now()
418-
print("Compiling...", file=sys.stdout)
428+
logger.info("Compiling...")
419429

420430
init_params = _get_batched_jittered_initial_points(
421431
model=model,
@@ -432,36 +442,49 @@ def sample_blackjax_nuts(
432442
seed = jax.random.PRNGKey(random_seed)
433443
keys = jax.random.split(seed, chains)
434444

435-
get_posterior_samples = partial(
436-
_blackjax_inference_loop,
437-
logprob_fn=logprob_fn,
438-
tune=tune,
439-
draws=draws,
440-
target_accept=target_accept,
441-
)
442-
443-
tic2 = datetime.now()
444-
print("Compilation time = ", tic2 - tic1, file=sys.stdout)
445-
446-
print("Sampling...", file=sys.stdout)
445+
if adaptation_kwargs is None:
446+
adaptation_kwargs = {}
447447

448448
# Adapted from numpyro
449449
if chain_method == "parallel":
450450
map_fn = jax.pmap
451+
if progress_bar:
452+
import warnings
453+
454+
warnings.warn(
455+
"BlackJax currently only display progress bar correctly under "
456+
"`chain_method == 'vectorized'`. Setting `progressbar=False`."
457+
)
458+
progress_bar = False
451459
elif chain_method == "vectorized":
452460
map_fn = jax.vmap
453461
else:
454462
raise ValueError(
455463
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
456464
)
457465

466+
adaptation_kwargs["progress_bar"] = progress_bar
467+
get_posterior_samples = partial(
468+
_blackjax_inference_loop,
469+
logprob_fn=logprob_fn,
470+
tune=tune,
471+
draws=draws,
472+
target_accept=target_accept,
473+
**adaptation_kwargs,
474+
)
475+
476+
tic2 = datetime.now()
477+
logger.info(f"Compilation time = {tic2 - tic1}")
478+
479+
logger.info("Sampling...")
480+
458481
states, stats = map_fn(get_posterior_samples)(keys, init_params)
459482
raw_mcmc_samples = states.position
460-
potential_energy = states.logdensity
483+
potential_energy = states.logdensity.block_until_ready()
461484
tic3 = datetime.now()
462-
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
485+
logger.info(f"Sampling time = {tic3 - tic2}")
463486

464-
print("Transforming variables...", file=sys.stdout)
487+
logger.info("Transforming variables...")
465488
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
466489
result = _postprocess_samples(
467490
jax_fn,
@@ -472,7 +495,7 @@ def sample_blackjax_nuts(
472495
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
473496
mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
474497
tic4 = datetime.now()
475-
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
498+
logger.info(f"Transformation time = {tic4 - tic3}")
476499

477500
if idata_kwargs is None:
478501
idata_kwargs = {}
@@ -481,15 +504,15 @@ def sample_blackjax_nuts(
481504

482505
if idata_kwargs.pop("log_likelihood", False):
483506
tic5 = datetime.now()
484-
print("Computing Log Likelihood...", file=sys.stdout)
507+
logger.info(f"Computing Log Likelihood...")
485508
log_likelihood = _get_log_likelihood(
486509
model,
487510
raw_mcmc_samples,
488511
backend=postprocessing_backend,
489512
postprocessing_vectorize=postprocessing_vectorize,
490513
)
491514
tic6 = datetime.now()
492-
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
515+
logger.info(f"Log Likelihood time = {tic6 - tic5}")
493516
else:
494517
log_likelihood = None
495518

@@ -634,7 +657,7 @@ def sample_numpyro_nuts(
634657
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
635658

636659
tic1 = datetime.now()
637-
print("Compiling...", file=sys.stdout)
660+
logger.info("Compiling...")
638661

639662
init_params = _get_batched_jittered_initial_points(
640663
model=model,
@@ -663,9 +686,9 @@ def sample_numpyro_nuts(
663686
)
664687

665688
tic2 = datetime.now()
666-
print("Compilation time = ", tic2 - tic1, file=sys.stdout)
689+
logger.info(f"Compilation time = {tic2 - tic1}")
667690

668-
print("Sampling...", file=sys.stdout)
691+
logger.info("Sampling...")
669692

670693
map_seed = jax.random.PRNGKey(random_seed)
671694
if chains > 1:
@@ -687,9 +710,9 @@ def sample_numpyro_nuts(
687710
raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
688711

689712
tic3 = datetime.now()
690-
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
713+
logger.info(f"Sampling time = {tic3 - tic2}")
691714

692-
print("Transforming variables...", file=sys.stdout)
715+
logger.info("Transforming variables...")
693716
jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
694717
result = _postprocess_samples(
695718
jax_fn,
@@ -700,7 +723,7 @@ def sample_numpyro_nuts(
700723
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
701724

702725
tic4 = datetime.now()
703-
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
726+
logger.info(f"Transformation time = {tic4 - tic3}")
704727

705728
if idata_kwargs is None:
706729
idata_kwargs = {}
@@ -709,15 +732,17 @@ def sample_numpyro_nuts(
709732

710733
if idata_kwargs.pop("log_likelihood", False):
711734
tic5 = datetime.now()
712-
print("Computing Log Likelihood...", file=sys.stdout)
735+
logger.info(f"Computing Log Likelihood...")
713736
log_likelihood = _get_log_likelihood(
714737
model,
715738
raw_mcmc_samples,
716739
backend=postprocessing_backend,
717740
postprocessing_vectorize=postprocessing_vectorize,
718741
)
719742
tic6 = datetime.now()
720-
print("Log Likelihood time = ", tic6 - tic5, file=sys.stdout)
743+
logger.info(
744+
f"Log Likelihood time = {tic6 - tic5}",
745+
)
721746
else:
722747
log_likelihood = None
723748

pymc/sampling/mcmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def _sample_external_nuts(
372372
random_seed=random_seed,
373373
initvals=initvals,
374374
model=model,
375+
progress_bar=progressbar,
375376
idata_kwargs=idata_kwargs,
376377
**nuts_sampler_kwargs,
377378
)

0 commit comments

Comments
 (0)