Skip to content

Commit 6b39941

Browse files
authored
Merge pull request #2046 from pymc-devs/transformed_init_fix
Ensure auto-transformed variables use specified starting values
2 parents ec0cd55 + 6117810 commit 6b39941

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

pymc3/distributions/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from ..math import logit, invlogit
77
import numpy as np
88

9-
__all__ = ['transform', 'stick_breaking', 'logodds',
10-
'log', 'sum_to_1', 't_stick_breaking']
9+
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
10+
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
1111

1212

1313
class Transform(object):

pymc3/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,11 +989,14 @@ def __init__(self, type=None, owner=None, index=None, name=None,
989989
if type is None:
990990
type = distribution.type
991991
super(TransformedRV, self).__init__(type, owner, index, name)
992+
993+
self.transformation = transform
992994

993995
if distribution is not None:
994996
self.model = model
995997

996998
transformed_name = "{}_{}_".format(name, transform.name)
999+
9971000
self.transformed = model.Var(
9981001
transformed_name, transform.apply(distribution), total_size=total_size)
9991002

pymc3/sampling.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
348348
strace = _choose_backend(trace, chain, model=model)
349349

350350
if len(strace) > 0:
351-
_soft_update(start, strace.point(-1))
351+
_update_start_vals(start, strace.point(-1), model)
352352
else:
353-
_soft_update(start, model.test_point)
353+
_update_start_vals(start, model.test_point, model)
354354

355355
try:
356356
step = CompoundStep(step)
@@ -457,13 +457,19 @@ def stop_tuning(step):
457457

458458
return step
459459

460-
461-
def _soft_update(a, b):
462-
"""As opposed to dict.update, don't overwrite keys if present.
460+
def _update_start_vals(a, b, model):
461+
"""Update a with b, without overwriting existing keys. Values specified for
462+
transformed variables on the original scale are also transformed and inserted.
463463
"""
464+
465+
for name in a:
466+
for tname in b:
467+
if tname.startswith(name) and tname!=name:
468+
transform_func = [d.transformation for d in model.deterministics if d.name==name][0]
469+
b[tname] = transform_func.forward(a[name]).eval()
470+
464471
a.update({k: v for k, v in b.items() if k not in a})
465472

466-
467473
def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
468474
random_seed=None, progressbar=True):
469475
"""Generate posterior predictive samples from a model given a trace.

pymc3/tests/test_diagnostics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def get_ptrace(self, n_samples):
2121
# Run sampler
2222
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
2323
step2 = Metropolis([model.switchpoint])
24-
start = {'early_mean': 7., 'late_mean': 1., 'switchpoint': 100}
24+
start = {'early_mean': 7., 'late_mean': 5., 'switchpoint': 10}
2525
ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2,
2626
progressbar=False, random_seed=[20090425, 19700903])
2727
return ptrace
@@ -35,15 +35,15 @@ def test_good(self):
3535

3636
def test_bad(self):
3737
"""Confirm Gelman-Rubin statistic is far from 1 for a small number of samples."""
38-
n_samples = 10
38+
n_samples = 5
3939
rhat = gelman_rubin(self.get_ptrace(n_samples))
4040
assert not all(1 / self.good_ratio < r <
4141
self.good_ratio for r in rhat.values())
4242

4343
def test_right_shape_python_float(self, shape=None, test_shape=None):
4444
"""Check Gelman-Rubin statistic shape is correct w/ python float"""
4545
n_jobs = 3
46-
n_samples = 10
46+
n_samples = 5
4747

4848
with Model():
4949
if shape is not None:

pymc3/tests/test_sampling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from itertools import combinations
22
import numpy as np
3+
from numpy.testing import assert_almost_equal
4+
35
try:
46
import unittest.mock as mock # py3
57
except ImportError:
@@ -117,6 +119,12 @@ def test_soft_update_empty(self):
117119
test_point = {'a': 3, 'b': 4}
118120
pm.sampling._soft_update(start, test_point)
119121
assert start == test_point
122+
123+
def test_soft_update_transformed(self):
124+
start = {'a': 2}
125+
test_point = {'a_log_': 0}
126+
pm.sampling._soft_update(start, test_point)
127+
assert assert_almost_equal(start['a_log_'], np.log(start['a']))
120128

121129

122130
class TestNamedSampling(SeededTest):

0 commit comments

Comments
 (0)