|
6 | 6 |
|
7 | 7 | from typing import Callable, List, Optional
|
8 | 8 |
|
| 9 | +from pymc.sampling import _init_jitter |
| 10 | + |
9 | 11 | xla_flags = os.getenv("XLA_FLAGS", "")
|
10 | 12 | xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
|
11 | 13 | os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
|
@@ -139,12 +141,46 @@ def _get_log_likelihood(model, samples):
|
139 | 141 | return data
|
140 | 142 |
|
141 | 143 |
|
| 144 | +def _get_batched_jittered_initial_points( |
| 145 | + model, chains, initvals, random_seed, jitter=True, jitter_max_retries=10 |
| 146 | +): |
| 147 | + """Get jittered initial point in format expected by NumPyro MCMC kernel |
| 148 | +
|
| 149 | + Returns |
| 150 | + ------- |
| 151 | + out: list of ndarrays |
| 152 | + list with one item per variable and number of chains as batch dimension. |
| 153 | + Each item has shape `(chains, *var.shape)` |
| 154 | + """ |
| 155 | + if isinstance(random_seed, (int, np.integer)): |
| 156 | + random_seed = np.random.default_rng(random_seed).integers(2**30, size=chains) |
| 157 | + elif not isinstance(random_seed, (list, tuple, np.ndarray)): |
| 158 | + raise ValueError(f"The `seeds` must be int or array-like. Got {type(random_seed)} instead.") |
| 159 | + |
| 160 | + assert len(random_seed) == chains |
| 161 | + |
| 162 | + initial_points = _init_jitter( |
| 163 | + model, |
| 164 | + initvals, |
| 165 | + seeds=random_seed, |
| 166 | + jitter=jitter, |
| 167 | + jitter_max_retries=jitter_max_retries, |
| 168 | + ) |
| 169 | + initial_points = [list(initial_point.values()) for initial_point in initial_points] |
| 170 | + if chains == 1: |
| 171 | + initial_points = initial_points[0] |
| 172 | + else: |
| 173 | + initial_points = [np.stack(init_state) for init_state in zip(*initial_points)] |
| 174 | + return initial_points |
| 175 | + |
| 176 | + |
142 | 177 | def sample_numpyro_nuts(
|
143 | 178 | draws=1000,
|
144 | 179 | tune=1000,
|
145 | 180 | chains=4,
|
146 | 181 | target_accept=0.8,
|
147 |
| - random_seed=10, |
| 182 | + random_seed=None, |
| 183 | + initvals=None, |
148 | 184 | model=None,
|
149 | 185 | var_names=None,
|
150 | 186 | progress_bar=True,
|
@@ -176,13 +212,20 @@ def sample_numpyro_nuts(
|
176 | 212 | else:
|
177 | 213 | dims = {}
|
178 | 214 |
|
| 215 | + if random_seed is None: |
| 216 | + random_seed = model.rng_seeder.randint( |
| 217 | + 2**30, dtype=np.int64, size=chains if chains > 1 else None |
| 218 | + ) |
| 219 | + |
179 | 220 | tic1 = datetime.now()
|
180 | 221 | print("Compiling...", file=sys.stdout)
|
181 | 222 |
|
182 |
| - rv_names = [rv.name for rv in model.value_vars] |
183 |
| - initial_point = model.compute_initial_point() |
184 |
| - init_state = [initial_point[rv_name] for rv_name in rv_names] |
185 |
| - init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state) |
| 223 | + init_params = _get_batched_jittered_initial_points( |
| 224 | + model=model, |
| 225 | + chains=chains, |
| 226 | + initvals=initvals, |
| 227 | + random_seed=random_seed, |
| 228 | + ) |
186 | 229 |
|
187 | 230 | logp_fn = get_jaxified_logp(model)
|
188 | 231 |
|
@@ -212,14 +255,9 @@ def sample_numpyro_nuts(
|
212 | 255 |
|
213 | 256 | print("Sampling...", file=sys.stdout)
|
214 | 257 |
|
215 |
| - seed = jax.random.PRNGKey(random_seed) |
216 |
| - map_seed = jax.random.split(seed, chains) |
217 |
| - |
218 |
| - if chains == 1: |
219 |
| - init_params = init_state |
220 |
| - map_seed = seed |
221 |
| - else: |
222 |
| - init_params = init_state_batched |
| 258 | + map_seed = jax.random.PRNGKey(random_seed) |
| 259 | + if chains > 1: |
| 260 | + map_seed = jax.random.split(map_seed, chains) |
223 | 261 |
|
224 | 262 | pmap_numpyro.run(
|
225 | 263 | map_seed,
|
|
0 commit comments