Skip to content

Commit 7815e5b

Browse files
authored
Fix Blackjax SMC integration (#374)
1 parent e96d07f commit 7815e5b

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

pymc_experimental/inference/smc/sampling.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import jax.numpy as jnp
2626
import numpy as np
2727

28+
from blackjax.smc import extend_params
2829
from blackjax.smc.resampling import systematic
2930
from pymc import draw, modelcontext, to_inference_data
3031
from pymc.backends import NDArray
@@ -126,16 +127,20 @@ def sample_smc_blackjax(
126127

127128
if kernel == "HMC":
128129
mcmc_kernel = blackjax.mcmc.hmc
129-
mcmc_parameters = dict(
130-
step_size=inner_kernel_params["step_size"],
131-
inverse_mass_matrix=jnp.eye(posterior_dimensions),
132-
num_integration_steps=inner_kernel_params["integration_steps"],
130+
mcmc_parameters = extend_params(
131+
dict(
132+
step_size=inner_kernel_params["step_size"],
133+
inverse_mass_matrix=jnp.eye(posterior_dimensions),
134+
num_integration_steps=inner_kernel_params["integration_steps"],
135+
)
133136
)
134137
elif kernel == "NUTS":
135138
mcmc_kernel = blackjax.mcmc.nuts
136-
mcmc_parameters = dict(
137-
step_size=inner_kernel_params["step_size"],
138-
inverse_mass_matrix=jnp.eye(posterior_dimensions),
139+
mcmc_parameters = extend_params(
140+
dict(
141+
step_size=inner_kernel_params["step_size"],
142+
inverse_mass_matrix=jnp.eye(posterior_dimensions),
143+
)
139144
)
140145
else:
141146
raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")

tests/test_blackjax_smc.py

-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def fast_model():
8080
("NUTS", False, {"step_size": 0.1}),
8181
],
8282
)
83-
@pytest.mark.xfail(reason="Still need to investigate")
8483
def test_sample_smc_blackjax(kernel, check_for_integration_steps, inner_kernel_params):
8584
"""
8685
When running the two gaussians model

0 commit comments

Comments
 (0)