|
| 1 | +# pylint: skip-file |
| 2 | +import os |
| 3 | +import re |
| 4 | +import warnings |
| 5 | + |
| 6 | +xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--") |
| 7 | +xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split() |
| 8 | +os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)]) |
| 9 | + |
| 10 | +import arviz as az |
| 11 | +import jax |
| 12 | +import numpy as np |
| 13 | +import pandas as pd |
| 14 | +import theano |
| 15 | +import theano.sandbox.jax_linker |
| 16 | +import theano.sandbox.jaxify |
| 17 | + |
| 18 | +import pymc3 as pm |
| 19 | + |
| 20 | +from pymc3 import modelcontext |
| 21 | + |
| 22 | +warnings.warn("This module is experimental.") |
| 23 | + |
| 24 | +# Disable C compilation by default |
| 25 | +# theano.config.cxx = "" |
| 26 | +# This will make the JAX Linker the default |
| 27 | +# theano.config.mode = "JAX" |
| 28 | + |
| 29 | + |
| 30 | +def sample_tfp_nuts( |
| 31 | + draws=1000, |
| 32 | + tune=1000, |
| 33 | + chains=4, |
| 34 | + target_accept=0.8, |
| 35 | + random_seed=10, |
| 36 | + model=None, |
| 37 | + num_tuning_epoch=2, |
| 38 | + num_compute_step_size=500, |
| 39 | +): |
| 40 | + from tensorflow_probability.substrates import jax as tfp |
| 41 | + |
| 42 | + model = modelcontext(model) |
| 43 | + |
| 44 | + seed = jax.random.PRNGKey(random_seed) |
| 45 | + |
| 46 | + fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt]) |
| 47 | + fns = theano.sandbox.jaxify.jax_funcify(fgraph) |
| 48 | + logp_fn_jax = fns[0] |
| 49 | + |
| 50 | + rv_names = [rv.name for rv in model.free_RVs] |
| 51 | + init_state = [model.test_point[rv_name] for rv_name in rv_names] |
| 52 | + init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) |
| 53 | + |
| 54 | + @jax.pmap |
| 55 | + def _sample(init_state, seed): |
| 56 | + def gen_kernel(step_size): |
| 57 | + hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size) |
| 58 | + return tfp.mcmc.DualAveragingStepSizeAdaptation( |
| 59 | + hmc, tune // num_tuning_epoch, target_accept_prob=target_accept |
| 60 | + ) |
| 61 | + |
| 62 | + def trace_fn(_, pkr): |
| 63 | + return pkr.new_step_size |
| 64 | + |
| 65 | + def get_tuned_stepsize(samples, step_size): |
| 66 | + return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:]) |
| 67 | + |
| 68 | + step_size = jax.tree_map(jax.numpy.ones_like, init_state) |
| 69 | + for i in range(num_tuning_epoch - 1): |
| 70 | + tuning_hmc = gen_kernel(step_size) |
| 71 | + init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain( |
| 72 | + num_results=tune // num_tuning_epoch, |
| 73 | + current_state=init_state, |
| 74 | + kernel=tuning_hmc, |
| 75 | + trace_fn=trace_fn, |
| 76 | + return_final_kernel_results=True, |
| 77 | + seed=seed, |
| 78 | + ) |
| 79 | + |
| 80 | + step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result) |
| 81 | + init_state = [x[-1] for x in init_samples] |
| 82 | + |
| 83 | + # Run inference |
| 84 | + sample_kernel = gen_kernel(step_size) |
| 85 | + mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain( |
| 86 | + num_results=draws, |
| 87 | + num_burnin_steps=tune // num_tuning_epoch, |
| 88 | + current_state=init_state, |
| 89 | + kernel=sample_kernel, |
| 90 | + trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken, |
| 91 | + seed=seed, |
| 92 | + ) |
| 93 | + |
| 94 | + return mcmc_samples, leapfrog_num |
| 95 | + |
| 96 | + print("Compiling...") |
| 97 | + tic2 = pd.Timestamp.now() |
| 98 | + map_seed = jax.random.split(seed, chains) |
| 99 | + mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed) |
| 100 | + tic3 = pd.Timestamp.now() |
| 101 | + print("Compilation + sampling time = ", tic3 - tic2) |
| 102 | + |
| 103 | + # map_seed = jax.random.split(seed, chains) |
| 104 | + # mcmc_samples = _sample(init_state_batched, map_seed) |
| 105 | + # tic4 = pd.Timestamp.now() |
| 106 | + # print("Sampling time = ", tic4 - tic3) |
| 107 | + |
| 108 | + posterior = {k: v for k, v in zip(rv_names, mcmc_samples)} |
| 109 | + |
| 110 | + az_trace = az.from_dict(posterior=posterior) |
| 111 | + return az_trace # , leapfrog_num, tic3 - tic2 |
| 112 | + |
| 113 | + import jax |
| 114 | + |
| 115 | + |
| 116 | +def sample_numpyro_nuts( |
| 117 | + draws=1000, |
| 118 | + tune=1000, |
| 119 | + chains=4, |
| 120 | + target_accept=0.8, |
| 121 | + random_seed=10, |
| 122 | + model=None, |
| 123 | + progress_bar=True, |
| 124 | +): |
| 125 | + from numpyro.infer import MCMC, NUTS |
| 126 | + |
| 127 | + from pymc3 import modelcontext |
| 128 | + |
| 129 | + model = modelcontext(model) |
| 130 | + |
| 131 | + seed = jax.random.PRNGKey(random_seed) |
| 132 | + |
| 133 | + fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt]) |
| 134 | + fns = theano.sandbox.jaxify.jax_funcify(fgraph) |
| 135 | + logp_fn_jax = fns[0] |
| 136 | + |
| 137 | + rv_names = [rv.name for rv in model.free_RVs] |
| 138 | + init_state = [model.test_point[rv_name] for rv_name in rv_names] |
| 139 | + init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) |
| 140 | + |
| 141 | + @jax.jit |
| 142 | + def _sample(current_state, seed): |
| 143 | + step_size = jax.tree_map(jax.numpy.ones_like, init_state) |
| 144 | + nuts_kernel = NUTS( |
| 145 | + potential_fn=lambda x: -logp_fn_jax(*x), |
| 146 | + # model=model, |
| 147 | + target_accept_prob=target_accept, |
| 148 | + adapt_step_size=True, |
| 149 | + adapt_mass_matrix=True, |
| 150 | + dense_mass=False, |
| 151 | + ) |
| 152 | + |
| 153 | + pmap_numpyro = MCMC( |
| 154 | + nuts_kernel, |
| 155 | + num_warmup=tune, |
| 156 | + num_samples=draws, |
| 157 | + num_chains=chains, |
| 158 | + postprocess_fn=None, |
| 159 | + chain_method="parallel", |
| 160 | + progress_bar=progress_bar, |
| 161 | + ) |
| 162 | + |
| 163 | + pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",)) |
| 164 | + samples = pmap_numpyro.get_samples(group_by_chain=True) |
| 165 | + leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"] |
| 166 | + return samples, leapfrogs_taken |
| 167 | + |
| 168 | + print("Compiling...") |
| 169 | + tic2 = pd.Timestamp.now() |
| 170 | + map_seed = jax.random.split(seed, chains) |
| 171 | + mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed) |
| 172 | + tic3 = pd.Timestamp.now() |
| 173 | + print("Compilation + sampling time = ", tic3 - tic2) |
| 174 | + |
| 175 | + # map_seed = jax.random.split(seed, chains) |
| 176 | + # mcmc_samples = _sample(init_state_batched, map_seed) |
| 177 | + # tic4 = pd.Timestamp.now() |
| 178 | + # print("Sampling time = ", tic4 - tic3) |
| 179 | + |
| 180 | + posterior = {k: v for k, v in zip(rv_names, mcmc_samples)} |
| 181 | + |
| 182 | + az_trace = az.from_dict(posterior=posterior) |
| 183 | + return az_trace # , leapfrogs_taken, tic3 - tic2 |
0 commit comments