|
1208 | 1208 | "# Set the chain's start state.\n",
|
1209 | 1209 | "initial_chain_state = [\n",
|
1210 | 1210 | " tf.cast(tf.reduce_mean(count_data), tf.float32) * tf.ones([], dtype=tf.float32, name=\"init_lambda1\"),\n",
|
1211 |
| - " tf.cast(tf.reduce_mean(count_data), tf.float32) * tf.ones([], dtype=tf.float32, name=\"init_lambda2\", tf.float32),\n", |
| 1211 | + " tf.cast(tf.reduce_mean(count_data), tf.float32) * tf.ones([], dtype=tf.float32, name=\"init_lambda2\"),\n", |
1212 | 1212 | " 0.5 * tf.ones([], dtype=tf.float32, name=\"init_tau\"),\n",
|
1213 | 1213 | "]\n",
|
1214 | 1214 | "\n",
|
|
1273 | 1273 | " target_log_prob_fn=unnormalized_log_posterior,\n",
|
1274 | 1274 | " num_leapfrog_steps=2,\n",
|
1275 | 1275 | " step_size=step_size,\n",
|
1276 |
| - " step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),\n", |
| 1276 | + " step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(None),\n", |
1277 | 1277 | " state_gradients_are_stopped=True),\n",
|
1278 | 1278 | " bijector=unconstraining_bijectors))\n",
|
1279 | 1279 | "\n",
|
1280 |
| - "tau_samples = tf.floor(posterior_tau * tf.cast(tf.size(count_data)), tf.float32)\n", |
| 1280 | + "tau_samples = tf.floor(posterior_tau * tf.cast(tf.size(count_data), tf.float32))\n", |
1281 | 1281 | "\n",
|
1282 | 1282 | "# tau_samples, lambda_1_samples, lambda_2_samples contain\n",
|
1283 | 1283 | "# N samples from the corresponding posterior distribution\n",
|
|
0 commit comments