diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 390661fdc2..3439cfd470 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -168,7 +168,11 @@ def _get_log_likelihood(
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
     result = _postprocess_samples(
-        jax_fn, samples, backend, postprocessing_vectorize=postprocessing_vectorize
+        jax_fn,
+        samples,
+        backend,
+        postprocessing_vectorize=postprocessing_vectorize,
+        donate_samples=False,
     )
     return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
@@ -181,7 +185,8 @@ def _postprocess_samples(
     jax_fn: Callable,
     raw_mcmc_samples: list[TensorVariable],
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
+    postprocessing_vectorize: Literal["vmap", "scan"] = "vmap",
+    donate_samples: bool = False,
 ) -> list[TensorVariable]:
     if postprocessing_vectorize == "scan":
         t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
@@ -193,7 +198,12 @@ def _postprocess_samples(
         )
         return [jnp.swapaxes(t, 0, 1) for t in outs]
     elif postprocessing_vectorize == "vmap":
-        return jax.vmap(jax.vmap(jax_fn))(*_device_put(raw_mcmc_samples, postprocessing_backend))
+
+        def process_fn(x):
+            return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
+
+        return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
+
     else:
         raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
 
@@ -253,7 +263,16 @@ def _blackjax_inference_loop(
     def _one_step(state, xs):
         _, rng_key = xs
         state, info = kernel(rng_key, state)
-        return state, (state, info)
+        position = state.position
+        stats = {
+            "diverging": info.is_divergent,
+            "energy": info.energy,
+            "tree_depth": info.num_trajectory_expansions,
+            "n_steps": info.num_integration_steps,
+            "acceptance_rate": info.acceptance_rate,
+            "lp": state.logdensity,
+        }
+        return state, (position, stats)
 
     progress_bar = adaptation_kwargs.pop("progress_bar", False)
     if progress_bar:
@@ -264,43 +283,9 @@ def _one_step(state, xs):
         one_step = jax.jit(_one_step)
 
     keys = jax.random.split(seed, draws)
-    _, (states, infos) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
-
-    return states, infos
-
-
-def _blackjax_stats_to_dict(sample_stats, potential_energy) -> dict:
-    """Extract compatible stats from blackjax NUTS sampler
-    with PyMC/Arviz naming conventions.
-
-    Parameters
-    ----------
-    sample_stats: NUTSInfo
-        Blackjax NUTSInfo object containing sampler statistics
-    potential_energy: ArrayLike
-        Potential energy values of sampled positions.
+    _, (samples, stats) = jax.lax.scan(one_step, last_state, (jnp.arange(draws), keys))
 
-    Returns
-    -------
-    Dict[str, ArrayLike]
-        Dictionary of sampler statistics.
-    """
-    rename_key = {
-        "is_divergent": "diverging",
-        "energy": "energy",
-        "num_trajectory_expansions": "tree_depth",
-        "num_integration_steps": "n_steps",
-        "acceptance_rate": "acceptance_rate",  # naming here is
-        "acceptance_probability": "acceptance_rate",  # depending on blackjax version
-    }
-    converted_stats = {}
-    converted_stats["lp"] = potential_energy
-    for old_name, new_name in rename_key.items():
-        value = getattr(sample_stats, old_name, None)
-        if value is None:
-            continue
-        converted_stats[new_name] = value
-    return converted_stats
+    return samples, stats
 
 
 def _sample_blackjax_nuts(
@@ -410,11 +395,7 @@ def _sample_blackjax_nuts(
         **nuts_kwargs,
     )
 
-    states, stats = map_fn(get_posterior_samples)(keys, initial_points)
-    raw_mcmc_samples = states.position
-    potential_energy = states.logdensity.block_until_ready()
-    sample_stats = _blackjax_stats_to_dict(stats, potential_energy)
-
+    raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
    return raw_mcmc_samples, sample_stats, blackjax
 
 
@@ -515,7 +496,7 @@ def sample_jax_nuts(
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
+    postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
     postprocessing_chunks=None,
     idata_kwargs: dict | None = None,
     compute_convergence_checks: bool = True,
@@ -597,6 +578,16 @@ def sample_jax_nuts(
             DeprecationWarning,
         )
 
+    if postprocessing_vectorize is not None:
+        import warnings
+
+        warnings.warn(
+            'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.',
+            FutureWarning,
+        )
+    else:
+        postprocessing_vectorize = "vmap"
+
     model = modelcontext(model)
 
     if var_names is not None:
@@ -645,15 +636,6 @@ def sample_jax_nuts(
     )
     tic2 = datetime.now()
 
-    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = _postprocess_samples(
-        jax_fn,
-        raw_mcmc_samples,
-        postprocessing_backend=postprocessing_backend,
-        postprocessing_vectorize=postprocessing_vectorize,
-    )
-    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
-
     if idata_kwargs is None:
         idata_kwargs = {}
     else:
@@ -669,6 +651,17 @@ def sample_jax_nuts(
     else:
         log_likelihood = None
 
+    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
+    result = _postprocess_samples(
+        jax_fn,
+        raw_mcmc_samples,
+        postprocessing_backend=postprocessing_backend,
+        postprocessing_vectorize=postprocessing_vectorize,
+        donate_samples=True,
+    )
+    del raw_mcmc_samples
+    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
+
     attrs = {
         "sampling_time": (tic2 - tic1).total_seconds(),
         "tuning_steps": tune,
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 307f91bc1a..32d2702ff2 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -272,6 +272,7 @@ def _sample_external_nuts(
     var_names: Sequence[str] | None,
     progressbar: bool,
     idata_kwargs: dict | None,
+    compute_convergence_checks: bool,
     nuts_sampler_kwargs: dict | None,
     **kwargs,
 ):
@@ -364,6 +365,7 @@ def _sample_external_nuts(
             progressbar=progressbar,
             nuts_sampler=sampler,
             idata_kwargs=idata_kwargs,
+            compute_convergence_checks=compute_convergence_checks,
             **nuts_sampler_kwargs,
         )
         return idata
@@ -718,6 +720,7 @@ def joined_blas_limiter():
             raise ValueError(
                 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
             )
+
         with joined_blas_limiter():
             return _sample_external_nuts(
                 sampler=nuts_sampler,
@@ -731,6 +734,7 @@ def joined_blas_limiter():
                 var_names=var_names,
                 progressbar=progressbar,
                 idata_kwargs=idata_kwargs,
+                compute_convergence_checks=compute_convergence_checks,
                 nuts_sampler_kwargs=nuts_sampler_kwargs,
                 **kwargs,
             )
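
Note on the vmap branch of _postprocess_samples: the memory win comes from XLA buffer donation. Below is a minimal standalone sketch of the pattern, outside PyMC; `transform_draws` and the array shapes are illustrative assumptions, not PyMC API. On backends without donation support, JAX ignores the hint (with a warning) and copies instead.

    import jax
    import jax.numpy as jnp

    def transform_draws(x):
        # Stand-in for the jaxified graph applied to a single draw.
        return jnp.exp(x)

    draws = jnp.ones((4, 1000))  # (chains, draws)

    # donate_argnums=0 marks the first argument's buffers as donatable,
    # so XLA may write the transformed draws into the memory holding the
    # raw draws instead of allocating a second buffer. The caller must
    # not touch `draws` afterwards -- which is why sample_jax_nuts passes
    # donate_samples=True and then runs `del raw_mcmc_samples`.
    transformed = jax.jit(
        jax.vmap(jax.vmap(transform_draws)), donate_argnums=0
    )(draws)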
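
Likewise, the rewritten _one_step applies the PyMC/ArviZ renaming inside the scanned step, so jax.lax.scan stacks only the stats fields that are actually kept, replacing the post-hoc _blackjax_stats_to_dict pass over full NUTSInfo pytrees. A minimal sketch of that pattern; `toy_kernel` is a hypothetical stand-in for a blackjax NUTS step, not its real API:

    import jax
    import jax.numpy as jnp

    def toy_kernel(rng_key, position):
        # Hypothetical sampler step: random-walk move, always accepted.
        new_position = position + 0.1 * jax.random.normal(rng_key, position.shape)
        info = {"acceptance_rate": jnp.array(1.0), "diverging": jnp.array(False)}
        return new_position, info

    def one_step(position, rng_key):
        position, info = toy_kernel(rng_key, position)
        # Emit only the fields to keep; scan stacks every leaf of this
        # output pytree along a new leading draws axis.
        return position, (position, info)

    keys = jax.random.split(jax.random.PRNGKey(0), 100)
    _, (samples, stats) = jax.lax.scan(one_step, jnp.zeros(3), keys)
    # samples.shape == (100, 3); stats["acceptance_rate"].shape == (100,)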