|
24 | 24 | import jax
|
25 | 25 | import jax.numpy as jnp
|
26 | 26 | import numpy as np
|
| 27 | + |
27 | 28 | from blackjax.smc import extend_params
|
28 | 29 | from blackjax.smc.resampling import systematic
|
29 | 30 | from pymc import draw, modelcontext, to_inference_data
|
@@ -126,16 +127,20 @@ def sample_smc_blackjax(
|
126 | 127 |
|
127 | 128 | if kernel == "HMC":
|
128 | 129 | mcmc_kernel = blackjax.mcmc.hmc
|
129 |
| - mcmc_parameters = extend_params(dict( |
130 |
| - step_size=inner_kernel_params["step_size"], |
131 |
| - inverse_mass_matrix=jnp.eye(posterior_dimensions), |
132 |
| - num_integration_steps=inner_kernel_params["integration_steps"]) |
| 130 | + mcmc_parameters = extend_params( |
| 131 | + dict( |
| 132 | + step_size=inner_kernel_params["step_size"], |
| 133 | + inverse_mass_matrix=jnp.eye(posterior_dimensions), |
| 134 | + num_integration_steps=inner_kernel_params["integration_steps"], |
| 135 | + ) |
133 | 136 | )
|
134 | 137 | elif kernel == "NUTS":
|
135 | 138 | mcmc_kernel = blackjax.mcmc.nuts
|
136 |
| - mcmc_parameters = extend_params(dict( |
137 |
| - step_size=inner_kernel_params["step_size"], |
138 |
| - inverse_mass_matrix=jnp.eye(posterior_dimensions)) |
| 139 | + mcmc_parameters = extend_params( |
| 140 | + dict( |
| 141 | + step_size=inner_kernel_params["step_size"], |
| 142 | + inverse_mass_matrix=jnp.eye(posterior_dimensions), |
| 143 | + ) |
139 | 144 | )
|
140 | 145 | else:
|
141 | 146 | raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")
|
|
0 commit comments