Skip to content

Commit ebd3f9c

Browse files
committed
Integrating Blackjax's SMC into PyMC
1 parent 7e1af48 commit ebd3f9c

File tree

9 files changed

+752
-61
lines changed

9 files changed

+752
-61
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ jobs:
191191
test-subset:
192192
- tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
193193
- tests/model/test_core.py tests/sampling/test_mcmc.py
194-
- tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
194+
- tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/smc/from_blackjax/test_blackjax_smc.py tests/sampling/test_parallel.py
195195
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py
196196

197197
fail-fast: false

pymc/smc/from_blackjax/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

pymc/smc/from_blackjax/kernels.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import blackjax
15+
16+
from blackjax.smc.resampling import systematic
17+
from jax import numpy as jnp
18+
19+
20+
def build_smc_with_hmc_kernel(
21+
prior_log_prob,
22+
loglikelihood,
23+
posterior_dimensions,
24+
target_ess,
25+
num_mcmc_steps,
26+
kernel_parameters,
27+
):
28+
return build_blackjax_smc(
29+
prior_log_prob,
30+
loglikelihood,
31+
blackjax.mcmc.hmc,
32+
mcmc_parameters=dict(
33+
step_size=kernel_parameters["step_size"],
34+
inverse_mass_matrix=jnp.eye(posterior_dimensions),
35+
num_integration_steps=kernel_parameters["integration_steps"],
36+
),
37+
target_ess=target_ess,
38+
num_mcmc_steps=num_mcmc_steps,
39+
)
40+
41+
42+
def build_smc_with_nuts_kernel(
43+
prior_log_prob,
44+
loglikelihood,
45+
posterior_dimensions,
46+
target_ess,
47+
num_mcmc_steps,
48+
kernel_parameters,
49+
):
50+
return build_blackjax_smc(
51+
prior_log_prob,
52+
loglikelihood,
53+
blackjax.mcmc.nuts,
54+
mcmc_parameters=dict(
55+
step_size=kernel_parameters["step_size"],
56+
inverse_mass_matrix=jnp.eye(posterior_dimensions),
57+
),
58+
target_ess=target_ess,
59+
num_mcmc_steps=num_mcmc_steps,
60+
)
61+
62+
63+
def build_blackjax_smc(
64+
prior_log_prob, loglikelihood, sampler_module, mcmc_parameters, target_ess, num_mcmc_steps
65+
):
66+
return blackjax.adaptive_tempered_smc(
67+
prior_log_prob,
68+
loglikelihood,
69+
sampler_module.build_kernel(),
70+
sampler_module.init,
71+
mcmc_parameters=mcmc_parameters,
72+
resampling_fn=systematic,
73+
target_ess=target_ess,
74+
num_mcmc_steps=num_mcmc_steps,
75+
)

0 commit comments

Comments
 (0)