Skip to content

Commit eb9681e

Browse files
authored
Import jax in TFP function. Fix time measurement of sampler. (#4290)
1 parent 8c12820 commit eb9681e

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

pymc3/sampling_jax.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def sample_tfp_nuts(
3838
num_compute_step_size=500,
3939
):
4040
from tensorflow_probability.substrates import jax as tfp
41-
41+
import jax
4242
model = modelcontext(model)
4343

4444
seed = jax.random.PRNGKey(random_seed)
@@ -97,8 +97,6 @@ def get_tuned_stepsize(samples, step_size):
9797
tic2 = pd.Timestamp.now()
9898
map_seed = jax.random.split(seed, chains)
9999
mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
100-
tic3 = pd.Timestamp.now()
101-
print("Compilation + sampling time = ", tic3 - tic2)
102100

103101
# map_seed = jax.random.split(seed, chains)
104102
# mcmc_samples = _sample(init_state_batched, map_seed)
@@ -108,10 +106,10 @@ def get_tuned_stepsize(samples, step_size):
108106
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
109107

110108
az_trace = az.from_dict(posterior=posterior)
109+
tic3 = pd.Timestamp.now()
110+
print("Compilation + sampling time = ", tic3 - tic2)
111111
return az_trace # , leapfrog_num, tic3 - tic2
112112

113-
import jax
114-
115113

116114
def sample_numpyro_nuts(
117115
draws=1000,
@@ -169,9 +167,6 @@ def _sample(current_state, seed):
169167
tic2 = pd.Timestamp.now()
170168
map_seed = jax.random.split(seed, chains)
171169
mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
172-
tic3 = pd.Timestamp.now()
173-
print("Compilation + sampling time = ", tic3 - tic2)
174-
175170
# map_seed = jax.random.split(seed, chains)
176171
# mcmc_samples = _sample(init_state_batched, map_seed)
177172
# tic4 = pd.Timestamp.now()
@@ -180,4 +175,6 @@ def _sample(current_state, seed):
180175
posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
181176

182177
az_trace = az.from_dict(posterior=posterior)
178+
tic3 = pd.Timestamp.now()
179+
print("Compilation + sampling time = ", tic3 - tic2)
183180
return az_trace # , leapfrogs_taken, tic3 - tic2

0 commit comments

Comments
 (0)