395
395
"cumulative_headcounts = tf.gather(tf.cumsum(coin_flip_data), num_trials)\n",
396
396
"\n",
397
397
"rv_observed_heads = tfp.distributions.Beta(\n",
398
- " concentration1=tf.to_float (1 + cumulative_headcounts),\n",
399
- " concentration0=tf.to_float (1 + num_trials - cumulative_headcounts))\n",
398
+ " concentration1=tf.cast (1 + cumulative_headcounts, tf.float32 ),\n",
399
+ " concentration0=tf.cast (1 + num_trials - cumulative_headcounts, tf.float32 ))\n",
400
400
"\n",
401
401
"probs_of_heads = tf.linspace(start=0., stop=1., num=100, name=\"linspace\")\n",
402
402
"observed_probs_heads = tf.transpose(rv_observed_heads.prob(probs_of_heads[:, tf.newaxis]))"
1207
1207
"source": [
1208
1208
"# Set the chain's start state.\n",
1209
1209
"initial_chain_state = [\n",
1210
- " tf.to_float (tf.reduce_mean(count_data)) * tf.ones([], dtype=tf.float32, name=\"init_lambda1\"),\n",
1211
- " tf.to_float (tf.reduce_mean(count_data)) * tf.ones([], dtype=tf.float32, name=\"init_lambda2\"),\n",
1210
+ " tf.cast (tf.reduce_mean(count_data)) * tf.ones([], dtype=tf.float32, name=\"init_lambda1\", tf.float32 ),\n",
1211
+ " tf.cast (tf.reduce_mean(count_data)) * tf.ones([], dtype=tf.float32, name=\"init_lambda2\", tf.float32 ),\n",
1212
1212
" 0.5 * tf.ones([], dtype=tf.float32, name=\"init_tau\"),\n",
1213
1213
"]\n",
1214
1214
"\n",
1234
1234
"\n",
1235
1235
" lambda_ = tf.gather(\n",
1236
1236
" [lambda_1, lambda_2],\n",
1237
- " indices=tf.to_int32(tau * tf.to_float (tf.size(count_data)) <= tf.to_float (tf.range(tf.size(count_data)))))\n",
1237
+ " indices=tf.to_int32(tau * tf.cast (tf.size(count_data), tf.float32 ) <= tf.cast (tf.range(tf.size(count_data)))), tf.float32 )\n",
1238
1238
" rv_observation = tfd.Poisson(rate=lambda_)\n",
1239
1239
" \n",
1240
1240
" return (\n",
1277
1277
" state_gradients_are_stopped=True),\n",
1278
1278
" bijector=unconstraining_bijectors))\n",
1279
1279
"\n",
1280
- "tau_samples = tf.floor(posterior_tau * tf.to_float (tf.size(count_data)))\n",
1280
+ "tau_samples = tf.floor(posterior_tau * tf.cast (tf.size(count_data)), tf.float32 )\n",
1281
1281
"\n",
1282
1282
"# tau_samples, lambda_1_samples, lambda_2_samples contain\n",
1283
1283
"# N samples from the corresponding posterior distribution\n",
1644
1644
]
1645
1645
}
1646
1646
]
1647
- }
1647
+ }
0 commit comments