Skip to content

Ensure auto-transformed variables use specified starting values #2046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 20, 2017
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
01c4984
Added _transformed_init function to sampling, which ensures transform…
fonnesbeck Apr 16, 2017
288ab35
Added test for _transformed_init in sample
fonnesbeck Apr 16, 2017
84e908e
Added example in _transformed_init docstring
fonnesbeck Apr 16, 2017
b9c61fc
Added interval to __all__ in transforms
fonnesbeck Apr 16, 2017
3ac6735
Fixed namespace error
fonnesbeck Apr 16, 2017
0ee34be
Encoded bound into interval transforms so that starting values can be…
fonnesbeck Apr 16, 2017
1a748eb
Simplified transformation of start values by carrying model forward i…
fonnesbeck Apr 17, 2017
555bc27
Added _transformed_init function to sampling, which ensures transform…
fonnesbeck Apr 16, 2017
c60a914
Added test for _transformed_init in sample
fonnesbeck Apr 16, 2017
644a5b6
Added example in _transformed_init docstring
fonnesbeck Apr 16, 2017
5345811
Added interval to __all__ in transforms
fonnesbeck Apr 16, 2017
4fd9029
Fixed namespace error
fonnesbeck Apr 16, 2017
ecf980f
Encoded bound into interval transforms so that starting values can be…
fonnesbeck Apr 16, 2017
c54f975
Simplified transformation of start values by carrying model forward i…
fonnesbeck Apr 17, 2017
d86031d
Merge branch 'transformed_init_fix' of github.com:pymc-devs/pymc3 int…
fonnesbeck Apr 17, 2017
b9436f0
Attempt to fix spurious 2.7 diagnostic test error
fonnesbeck Apr 19, 2017
863079e
Better fix for failing test
fonnesbeck Apr 19, 2017
b1866e9
Tweaks to fix stochastic diagnostic test error
fonnesbeck Apr 20, 2017
792dd0d
Merge branch 'master' into transformed_init_fix
fonnesbeck Apr 20, 2017
6117810
Renamed _soft_update to _updates_start_vals
fonnesbeck Apr 20, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc3/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from ..math import logit, invlogit
import numpy as np

__all__ = ['transform', 'stick_breaking', 'logodds',
'log', 'sum_to_1', 't_stick_breaking']
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']


class Transform(object):
Expand Down
3 changes: 3 additions & 0 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,11 +989,14 @@ def __init__(self, type=None, owner=None, index=None, name=None,
if type is None:
type = distribution.type
super(TransformedRV, self).__init__(type, owner, index, name)

self.transformation = transform

if distribution is not None:
self.model = model

transformed_name = "{}_{}_".format(name, transform.name)

self.transformed = model.Var(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we monkey-patch lower and upper onto self.transformed? While not clean, it would be cleaner than this.

transformed_name, transform.apply(distribution), total_size=total_size)

Expand Down
18 changes: 12 additions & 6 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
strace = _choose_backend(trace, chain, model=model)

if len(strace) > 0:
_soft_update(start, strace.point(-1))
_soft_update(start, strace.point(-1), model)
else:
_soft_update(start, model.test_point)
_soft_update(start, model.test_point, model)

try:
step = CompoundStep(step)
Expand Down Expand Up @@ -457,13 +457,19 @@ def stop_tuning(step):

return step


def _soft_update(a, b):
"""As opposed to dict.update, don't overwrite keys if present.
def _soft_update(a, b, model):
"""Update a with b, without overwriting existing keys. Values specified for
transformed variables on the original scale are also transformed and inserted.
"""

for name in a:
for tname in b:
if tname.startswith(name) and tname!=name:
transform_func = [d.transformation for d in model.deterministics if d.name==name][0]
b[tname] = transform_func.forward(a[name]).eval()

a.update({k: v for k, v in b.items() if k not in a})


def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
random_seed=None, progressbar=True):
"""Generate posterior predictive samples from a model given a trace.
Expand Down
6 changes: 3 additions & 3 deletions pymc3/tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_ptrace(self, n_samples):
# Run sampler
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
step2 = Metropolis([model.switchpoint])
start = {'early_mean': 7., 'late_mean': 1., 'switchpoint': 100}
start = {'early_mean': 7., 'late_mean': 5., 'switchpoint': 10}
ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2,
progressbar=False, random_seed=[20090425, 19700903])
return ptrace
Expand All @@ -35,15 +35,15 @@ def test_good(self):

def test_bad(self):
"""Confirm Gelman-Rubin statistic is far from 1 for a small number of samples."""
n_samples = 10
n_samples = 5
rhat = gelman_rubin(self.get_ptrace(n_samples))
assert not all(1 / self.good_ratio < r <
self.good_ratio for r in rhat.values())

def test_right_shape_python_float(self, shape=None, test_shape=None):
"""Check Gelman-Rubin statistic shape is correct w/ python float"""
n_jobs = 3
n_samples = 10
n_samples = 5

with Model():
if shape is not None:
Expand Down
8 changes: 8 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from itertools import combinations
import numpy as np
from numpy.testing import assert_almost_equal

try:
import unittest.mock as mock # py3
except ImportError:
Expand Down Expand Up @@ -117,6 +119,12 @@ def test_soft_update_empty(self):
test_point = {'a': 3, 'b': 4}
pm.sampling._soft_update(start, test_point)
assert start == test_point

def test_soft_update_transformed(self):
start = {'a': 2}
test_point = {'a_log_': 0}
pm.sampling._soft_update(start, test_point)
assert assert_almost_equal(start['a_log_'], np.log(start['a']))


class TestNamedSampling(SeededTest):
Expand Down