Skip to content

Commit b9f6b3c

Browse files
committed
bump blackjax in test environment
1 parent 3b65bb1 commit b9f6b3c

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

conda-envs/environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ dependencies:
1111
- statsmodels
1212
- pip:
1313
- pymc>=5.16.1 # CI was failing to resolve
14-
- blackjax
14+
- blackjax>=1.2.3
1515
- scikit-learn

pymc_experimental/inference/smc/sampling.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import jax
2525
import jax.numpy as jnp
2626
import numpy as np
27+
2728
from blackjax.smc import extend_params
2829
from blackjax.smc.resampling import systematic
2930
from pymc import draw, modelcontext, to_inference_data
@@ -126,16 +127,20 @@ def sample_smc_blackjax(
126127

127128
if kernel == "HMC":
128129
mcmc_kernel = blackjax.mcmc.hmc
129-
mcmc_parameters = extend_params(dict(
130-
step_size=inner_kernel_params["step_size"],
131-
inverse_mass_matrix=jnp.eye(posterior_dimensions),
132-
num_integration_steps=inner_kernel_params["integration_steps"])
130+
mcmc_parameters = extend_params(
131+
dict(
132+
step_size=inner_kernel_params["step_size"],
133+
inverse_mass_matrix=jnp.eye(posterior_dimensions),
134+
num_integration_steps=inner_kernel_params["integration_steps"],
135+
)
133136
)
134137
elif kernel == "NUTS":
135138
mcmc_kernel = blackjax.mcmc.nuts
136-
mcmc_parameters = extend_params(dict(
137-
step_size=inner_kernel_params["step_size"],
138-
inverse_mass_matrix=jnp.eye(posterior_dimensions))
139+
mcmc_parameters = extend_params(
140+
dict(
141+
step_size=inner_kernel_params["step_size"],
142+
inverse_mass_matrix=jnp.eye(posterior_dimensions),
143+
)
139144
)
140145
else:
141146
raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")

0 commit comments

Comments
 (0)