Skip to content

Commit f33a654

Browse files
author
Junpeng Lao
committed
Fix #2258
1 parent af2e68f commit f33a654

File tree

3 files changed

+61
-10
lines changed

3 files changed

+61
-10
lines changed

pymc3/distributions/transforms.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..theanof import gradient
55
from . import distribution
66
from ..math import logit, invlogit
7+
from .distribution import draw_values
78
import numpy as np
89

910
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
@@ -22,6 +23,9 @@ class Transform(object):
2223
def forward(self, x):
2324
raise NotImplementedError
2425

26+
def forward_val(self, x, point):
27+
raise NotImplementedError
28+
2529
def backward(self, z):
2630
raise NotImplementedError
2731

@@ -55,6 +59,7 @@ def __init__(self, dist, transform, *args, **kwargs):
5559
arguments to Distribution"""
5660
forward = transform.forward
5761
testval = forward(dist.default())
62+
forward_val = transform.forward_val
5863

5964
self.dist = dist
6065
self.transform_used = transform
@@ -85,6 +90,9 @@ def backward(self, x):
8590

8691
def forward(self, x):
8792
return tt.log(x)
93+
94+
def forward_val(self, x, point=None):
95+
return self.forward(x)
8896

8997
def jacobian_det(self, x):
9098
return x
@@ -103,6 +111,9 @@ def backward(self, x):
103111

104112
def forward(self, x):
105113
return logit(x)
114+
115+
def forward_val(self, x, point=None):
116+
return self.forward(x)
106117

107118
logodds = LogOdds()
108119

@@ -125,6 +136,11 @@ def forward(self, x):
125136
a, b = self.a, self.b
126137
return tt.log(x - a) - tt.log(b - x)
127138

139+
def forward_val(self, x, point=None):
140+
a, b = draw_values([self.a, self.b],
141+
point=point)
142+
return tt.log(x - a) - tt.log(b - x)
143+
128144
def jacobian_det(self, x):
129145
s = tt.nnet.softplus(-x)
130146
return tt.log(self.b - self.a) - 2 * s - x
@@ -147,8 +163,12 @@ def backward(self, x):
147163

148164
def forward(self, x):
149165
a = self.a
150-
r = tt.log(x - a)
151-
return r
166+
return tt.log(x - a)
167+
168+
def forward_val(self, x, point=None):
169+
a = draw_values([self.a],
170+
point=point)[0]
171+
return tt.log(x - a)
152172

153173
def jacobian_det(self, x):
154174
return x
@@ -171,8 +191,12 @@ def backward(self, x):
171191

172192
def forward(self, x):
173193
b = self.b
174-
r = tt.log(b - x)
175-
return r
194+
return tt.log(b - x)
195+
196+
def forward_val(self, x, point=None):
197+
b = draw_values([self.b],
198+
point=point)[0]
199+
return tt.log(b - x)
176200

177201
def jacobian_det(self, x):
178202
return x
@@ -191,6 +215,9 @@ def backward(self, y):
191215
def forward(self, x):
192216
return x[:-1]
193217

218+
def forward_val(self, x, point=None):
219+
return self.forward(x)
220+
194221
def jacobian_det(self, x):
195222
return 0
196223

@@ -224,6 +251,9 @@ def forward(self, x_):
224251
y = logit(z) - eq_share
225252
return y.T
226253

254+
def forward_val(self, x, point=None):
255+
return self.forward(x)
256+
227257
def backward(self, y_):
228258
y = y_.T
229259
Km1 = y.shape[0]
@@ -262,6 +292,9 @@ def backward(self, y):
262292
def forward(self, x):
263293
return tt.as_tensor_variable(x)
264294

295+
def forward_val(self, x, point=None):
296+
return self.forward(x)
297+
265298
def jacobian_det(self, x):
266299
return 0
267300

@@ -280,5 +313,8 @@ def backward(self, x):
280313
def forward(self, y):
281314
return tt.advanced_set_subtensor1(y, tt.log(y[self.diag_idxs]), self.diag_idxs)
282315

316+
def forward_val(self, x, point=None):
317+
return self.forward(x)
318+
283319
def jacobian_det(self, y):
284320
return tt.sum(y[self.diag_idxs])

pymc3/sampling.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -478,12 +478,14 @@ def _update_start_vals(a, b, model):
478478
"""Update a with b, without overwriting existing keys. Values specified for
479479
transformed variables on the original scale are also transformed and inserted.
480480
"""
481-
for name in a:
482-
for tname in b:
483-
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
484-
transform_func = [d.transformation for d in model.deterministics if d.name == name]
485-
if transform_func:
486-
b[tname] = transform_func[0].forward(a[name]).eval()
481+
if model is not None:
482+
for free_RV in model.free_RVs:
483+
tname = free_RV.name
484+
for name in a:
485+
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
486+
transform_func = [d.transformation for d in model.deterministics if d.name == name]
487+
if transform_func:
488+
b[tname] = transform_func[0].forward_val(a[name], point=b).eval()
487489

488490
a.update({k: v for k, v in b.items() if k not in a})
489491

pymc3/tests/test_sampling.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,19 @@ def test_soft_update_transformed(self):
150150
pm.sampling._update_start_vals(start, test_point, model)
151151
assert_almost_equal(np.exp(start['a_log__']), start['a'])
152152

153+
def test_soft_update_transformed2(self):
154+
with pm.Model() as model:
155+
a = pm.Uniform('a', lower=0., upper=1.)
156+
pm.Uniform('b', lower=0., upper=1.-a)
157+
start = {'a': .3, 'b': .5}
158+
test_point = {'a_interval__': -0.8472978603872037,
159+
'b_interval__': 0.9162907318741552}
160+
pm.sampling._update_start_vals(start, model.test_point, model)
161+
assert_almost_equal(start['a_interval__'],
162+
test_point['a_interval__'])
163+
assert_almost_equal(start['b_interval__'],
164+
test_point['b_interval__'])
165+
153166

154167
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
155168
class TestNamedSampling(SeededTest):

0 commit comments

Comments
 (0)