@@ -38,7 +38,7 @@ def sample_tfp_nuts(
38
38
num_compute_step_size = 500 ,
39
39
):
40
40
from tensorflow_probability .substrates import jax as tfp
41
-
41
+ import jax
42
42
model = modelcontext (model )
43
43
44
44
seed = jax .random .PRNGKey (random_seed )
@@ -97,8 +97,6 @@ def get_tuned_stepsize(samples, step_size):
97
97
tic2 = pd .Timestamp .now ()
98
98
map_seed = jax .random .split (seed , chains )
99
99
mcmc_samples , leapfrog_num = _sample (init_state_batched , map_seed )
100
- tic3 = pd .Timestamp .now ()
101
- print ("Compilation + sampling time = " , tic3 - tic2 )
102
100
103
101
# map_seed = jax.random.split(seed, chains)
104
102
# mcmc_samples = _sample(init_state_batched, map_seed)
@@ -108,10 +106,10 @@ def get_tuned_stepsize(samples, step_size):
108
106
posterior = {k : v for k , v in zip (rv_names , mcmc_samples )}
109
107
110
108
az_trace = az .from_dict (posterior = posterior )
109
+ tic3 = pd .Timestamp .now ()
110
+ print ("Compilation + sampling time = " , tic3 - tic2 )
111
111
return az_trace # , leapfrog_num, tic3 - tic2
112
112
113
- import jax
114
-
115
113
116
114
def sample_numpyro_nuts (
117
115
draws = 1000 ,
@@ -169,9 +167,6 @@ def _sample(current_state, seed):
169
167
tic2 = pd .Timestamp .now ()
170
168
map_seed = jax .random .split (seed , chains )
171
169
mcmc_samples , leapfrogs_taken = _sample (init_state_batched , map_seed )
172
- tic3 = pd .Timestamp .now ()
173
- print ("Compilation + sampling time = " , tic3 - tic2 )
174
-
175
170
# map_seed = jax.random.split(seed, chains)
176
171
# mcmc_samples = _sample(init_state_batched, map_seed)
177
172
# tic4 = pd.Timestamp.now()
@@ -180,4 +175,6 @@ def _sample(current_state, seed):
180
175
posterior = {k : v for k , v in zip (rv_names , mcmc_samples )}
181
176
182
177
az_trace = az .from_dict (posterior = posterior )
178
+ tic3 = pd .Timestamp .now ()
179
+ print ("Compilation + sampling time = " , tic3 - tic2 )
183
180
return az_trace # , leapfrogs_taken, tic3 - tic2
0 commit comments