Skip to content

Commit 7b0cf5d

Browse files
authored
Update developer guide (#3632)
- TF2.0 API changes - We dont have the same shape error any more.
1 parent 18c78f4 commit 7b0cf5d

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

docs/source/developer_guide.rst

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,23 @@ explicit about the conversion. For example:
151151
with pm.Model() as model:
152152
z = pm.Normal('z', mu=0., sigma=5.) # ==> pymc3.model.FreeRV, or theano.tensor with logp
153153
x = pm.Normal('x', mu=z, sigma=1., observed=5.) # ==> pymc3.model.ObservedRV, also has logp properties
154-
x.logp({'z': 2.5}) # ==> -4.0439386
155-
model.logp({'z': 2.5}) # ==> -6.6973152
154+
x.logp({'z': 2.5}) # ==> -4.0439386
155+
model.logp({'z': 2.5}) # ==> -6.6973152
156156
157157
**TFP**
158158
159159
.. code:: python
160160
161-
z_dist = tfd.Normal(loc=0., scale=5.) # ==> <class 'tfp.python.distributions.normal.Normal'>
162-
z = z_dist.sample() # ==> <class 'tensorflow.python.framework.ops.Tensor'>
163-
x = tfd.Normal(loc=z, scale=1.).log_prob(5.) # ==> <class 'tensorflow.python.framework.ops.Tensor'>
164-
model_logp = z_dist.log_prob(z) + x
165-
sess = tf.Session()
166-
sess.run(x, feed_dict={z: 2.5}) # ==> -4.0439386
167-
sess.run(model_logp, feed_dict={z: 2.5}) # ==> -6.6973152
161+
import tensorflow.compat.v1 as tf
162+
from tensorflow_probability import distributions as tfd
163+
164+
with tf.Session() as sess:
165+
z_dist = tfd.Normal(loc=0., scale=5.) # ==> <class 'tfp.python.distributions.normal.Normal'>
166+
z = z_dist.sample() # ==> <class 'tensorflow.python.framework.ops.Tensor'>
167+
x = tfd.Normal(loc=z, scale=1.).log_prob(5.) # ==> <class 'tensorflow.python.framework.ops.Tensor'>
168+
model_logp = z_dist.log_prob(z) + x
169+
print(sess.run(x, feed_dict={z: 2.5})) # ==> -4.0439386
170+
print(sess.run(model_logp, feed_dict={z: 2.5})) # ==> -6.6973152
168171
169172
**pyro**
170173
@@ -1046,7 +1049,10 @@ edge case especially in high dimensions. The biggest pain point is the
10461049
automatic broadcasting. As in the batch random generation, we want to
10471050
generate (n\_sample, ) + RV.shape random samples. In some cases, where
10481051
we broadcast RV1 and RV2 to create a RV3 that has one more batch shape,
1049-
we get error (even worse, wrong answer with silent error):
1052+
we get error (even worse, wrong answer with silent error).
1053+
1054+
The good news is, we are fixing these errors with the amazing works from [lucianopaz](https://github.com/lucianopaz) and
1055+
others. The challenge and some summary of the solution could be found in Luciano's [blog post](https://lucianopaz.github.io/2019/08/19/pymc3-shape-handling/)
10501056
10511057
.. code:: python
10521058
@@ -1056,11 +1062,11 @@ we get error (even worse, wrong answer with silent error):
10561062
pm.Normal('x', mu=mu, sigma=sd, observed=np.random.randn(2, 5, 10))
10571063
trace = pm.sample_prior_predictive(100)
10581064
1059-
trace['x'].shape # ==> should be (100, 2, 5, 10), but get (100, 5, 10)
1065+
trace['x'].shape # ==> should be (100, 2, 5, 10)
10601066
10611067
.. code:: python
10621068
1063-
pm.Normal.dist(mu=np.zeros(2), sigma=1).random(size=(10, 4)) # ==> ERROR
1069+
pm.Normal.dist(mu=np.zeros(2), sigma=1).random(size=(10, 4))
10641070
10651071
There are also other error related random sample generation (e.g.,
10661072
`Mixture is currently

0 commit comments

Comments
 (0)