From 01c4984e52ef48317758eccd468c0275373d2e8c Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 13:58:32 -0500 Subject: [PATCH 01/18] Added _transformed_init function to sampling, which ensures transformed variables are initialized with values specified by the user on the original variable --- pymc3/sampling.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 5936f1c143..8cd6f87505 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -12,6 +12,7 @@ BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep) from .plots.traceplot import traceplot +import pymc3.distributions as distributions from tqdm import tqdm import sys @@ -457,13 +458,26 @@ def stop_tuning(step): return step +def _transformed_init(a, b): + """Transforms original starting values for transformed variables specified by + the user (dict a) and inserts them into the init dict (dict b). + """ + + for name in a: + for tname in b: + if tname.startswith(name) and tname!=name: + transform = tname.split(name)[-1][1:-1] + transform_func = distributions.__dict__[transform] + b[tname] = transform_func.forward(a[name]).eval() + + return b def _soft_update(a, b): """As opposed to dict.update, don't overwrite keys if present. """ + b = _transformed_init(a, b) a.update({k: v for k, v in b.items() if k not in a}) - def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_seed=None, progressbar=True): """Generate posterior predictive samples from a model given a trace. From 288ab3524f6b5b5ea680c95633d21153d2d66b47 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 14:07:22 -0500 Subject: [PATCH 02/18] Added test for _transformed_init in sample --- pymc3/tests/test_sampling.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 3fd5f94041..bc19d68eae 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -1,5 +1,7 @@ from itertools import combinations import numpy as np +from numpy.testing import assert_almost_equal + try: import unittest.mock as mock # py3 except ImportError: @@ -117,6 +119,12 @@ def test_soft_update_empty(self): test_point = {'a': 3, 'b': 4} pm.sampling._soft_update(start, test_point) assert start == test_point + + def test_soft_update_transformed(self): + start = {'a': 2} + test_point = {'a_log_': 0} + pm.sampling._soft_update(start, test_point) + assert assert_almost_equal(start['a_log_'], np.log(start['a'])) class TestNamedSampling(SeededTest): From 84e908ed11531e73dc0afef6714bbf3fba68590d Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 14:12:14 -0500 Subject: [PATCH 03/18] Added example in _transformed_init docstring --- pymc3/sampling.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 8cd6f87505..65ff4384f5 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -461,6 +461,15 @@ def stop_tuning(step): def _transformed_init(a, b): """Transforms original starting values for transformed variables specified by the user (dict a) and inserts them into the init dict (dict b). + + Examples + -------- + .. code:: ipython + + >>> a = {'alpha': 5.0} + >>> b = {'alpha_log_': 0} + >>> _transformed_init(a, b) + {'alpha_log_': array(1.6094379425048828, dtype=float32)} """ for name in a: From b9c61fc5c573707e8fda2f569844956308e9f7c0 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 15:58:55 -0500 Subject: [PATCH 04/18] Added interval to __all__ in transforms --- pymc3/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 7b9e49f90b..8be405f7fb 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -6,7 +6,7 @@ from ..math import logit, invlogit import numpy as np -__all__ = ['transform', 'stick_breaking', 'logodds', +__all__ = ['transform', 'stick_breaking', 'logodds', 'interval', 'log', 'sum_to_1', 't_stick_breaking'] From 3ac6735c6cf47c323de3d29f4ddbbf6e8441d2b9 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 16:04:06 -0500 Subject: [PATCH 05/18] Fixed namespace error --- pymc3/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 65ff4384f5..13ed2090b3 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -476,7 +476,7 @@ def _transformed_init(a, b): for tname in b: if tname.startswith(name) and tname!=name: transform = tname.split(name)[-1][1:-1] - transform_func = distributions.__dict__[transform] + transform_func = distributions.transforms.__dict__[transform] b[tname] = transform_func.forward(a[name]).eval() return b From 0ee34beb42d3ad2938b054e940867c402fa05788 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 17:02:18 -0500 Subject: [PATCH 06/18] Encoded bound into interval transforms so that starting values can be properly transformed --- pymc3/distributions/transforms.py | 2 +- pymc3/model.py | 8 ++++++++ pymc3/sampling.py | 26 ++++++++++++++++++++------ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 8be405f7fb..80175d8b5d 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -7,7 +7,7 @@ import numpy as np __all__ = ['transform', 'stick_breaking', 'logodds', 'interval', - 'log', 'sum_to_1', 't_stick_breaking'] + 'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking'] class Transform(object): diff --git a/pymc3/model.py b/pymc3/model.py index c7b46e89d5..7abb2fba39 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -994,6 +994,14 @@ def __init__(self, type=None, owner=None, index=None, name=None, self.model = model transformed_name = "{}_{}_".format(name, transform.name) + bound_dict = {'a':None, 'b':None} + for key in bound_dict: + try: + bound_dict[key] = transform.__dict__[key].eval() + except KeyError: + pass + if bound_dict['a'] is not None or bound_dict['b'] is not None: + transformed_name += '({},{})_'.format(bound_dict['a'], bound_dict['b']) self.transformed = model.Var( transformed_name, transform.apply(distribution), total_size=total_size) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 13ed2090b3..e61bc4b5e7 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -458,9 +458,9 @@ def stop_tuning(step): return step -def _transformed_init(a, b): +def _transformed_init(start_vals, other_start_vals): """Transforms original starting values for transformed variables specified by - the user (dict a) and inserts them into the init dict (dict b). + the user (dict1) and inserts them into the init dict (dict2). Examples -------- @@ -472,14 +472,28 @@ def _transformed_init(a, b): {'alpha_log_': array(1.6094379425048828, dtype=float32)} """ - for name in a: - for tname in b: + for name in start_vals: + for tname in other_start_vals: if tname.startswith(name) and tname!=name: transform = tname.split(name)[-1][1:-1] + bound_dict = {} + # Check if bounds are encoded in name + if transform.find('(') > -1: + transform, interval = transform.split('_(') + a, b = interval[:-1].split(',') + if a: + bound_dict['a'] = float(a) + if b: + bound_dict['b'] = float(b) transform_func = distributions.transforms.__dict__[transform] - b[tname] = transform_func.forward(a[name]).eval() + try: + # Some transformations need to be instantiated + transform_func = transform_func(**bound_dict) + except TypeError: + pass + other_start_vals[tname] = transform_func.forward(start_vals[name]).eval() - return b + return other_start_vals def _soft_update(a, b): """As opposed to dict.update, don't overwrite keys if present. From 1a748ebb39a69949ace109c16bcb88e429721fa5 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Mon, 17 Apr 2017 09:22:35 -0500 Subject: [PATCH 07/18] Simplified transformation of start values by carrying model forward into _soft_update --- pymc3/model.py | 11 +++------- pymc3/sampling.py | 53 ++++++++++------------------------------------- 2 files changed, 14 insertions(+), 50 deletions(-) diff --git a/pymc3/model.py b/pymc3/model.py index 7abb2fba39..77635278b8 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -989,19 +989,14 @@ def __init__(self, type=None, owner=None, index=None, name=None, if type is None: type = distribution.type super(TransformedRV, self).__init__(type, owner, index, name) + + self.transformation = transform if distribution is not None: self.model = model transformed_name = "{}_{}_".format(name, transform.name) - bound_dict = {'a':None, 'b':None} - for key in bound_dict: - try: - bound_dict[key] = transform.__dict__[key].eval() - except KeyError: - pass - if bound_dict['a'] is not None or bound_dict['b'] is not None: - transformed_name += '({},{})_'.format(bound_dict['a'], bound_dict['b']) + self.transformed = model.Var( transformed_name, transform.apply(distribution), total_size=total_size) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index e61bc4b5e7..81f2ce047c 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -12,7 +12,6 @@ BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep) from .plots.traceplot import traceplot -import pymc3.distributions as distributions from tqdm import tqdm import sys @@ -349,9 +348,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, strace = _choose_backend(trace, chain, model=model) if len(strace) > 0: - _soft_update(start, strace.point(-1)) + _soft_update(start, strace.point(-1), model) else: - _soft_update(start, model.test_point) + _soft_update(start, model.test_point, model) try: step = CompoundStep(step) @@ -458,47 +457,17 @@ def stop_tuning(step): return step -def _transformed_init(start_vals, other_start_vals): - """Transforms original starting values for transformed variables specified by - the user (dict1) and inserts them into the init dict (dict2). - - Examples - -------- - .. code:: ipython - - >>> a = {'alpha': 5.0} - >>> b = {'alpha_log_': 0} - >>> _transformed_init(a, b) - {'alpha_log_': array(1.6094379425048828, dtype=float32)} +def _soft_update(a, b, model): + """Update a with b, without overwriting existing keys. Values specified for + transformed variables on the original scale are also transformed and inserted. """ - - for name in start_vals: - for tname in other_start_vals: + + for name in a: + for tname in b: if tname.startswith(name) and tname!=name: - transform = tname.split(name)[-1][1:-1] - bound_dict = {} - # Check if bounds are encoded in name - if transform.find('(') > -1: - transform, interval = transform.split('_(') - a, b = interval[:-1].split(',') - if a: - bound_dict['a'] = float(a) - if b: - bound_dict['b'] = float(b) - transform_func = distributions.transforms.__dict__[transform] - try: - # Some transformations need to be instantiated - transform_func = transform_func(**bound_dict) - except TypeError: - pass - other_start_vals[tname] = transform_func.forward(start_vals[name]).eval() - - return other_start_vals - -def _soft_update(a, b): - """As opposed to dict.update, don't overwrite keys if present. - """ - b = _transformed_init(a, b) + transform_func = [d.transformation for d in model.deterministics if d.name==name][0] + b[tname] = transform_func.forward(a[name]).eval() + a.update({k: v for k, v in b.items() if k not in a}) def sample_ppc(trace, samples=None, model=None, vars=None, size=None, From 555bc27312fe92f8ba868ee804fc4c729fa3ec20 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 13:58:32 -0500 Subject: [PATCH 08/18] Added _transformed_init function to sampling, which ensures transformed variables are initialized with values specified by the user on the original variable --- pymc3/sampling.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 5936f1c143..8cd6f87505 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -12,6 +12,7 @@ BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep) from .plots.traceplot import traceplot +import pymc3.distributions as distributions from tqdm import tqdm import sys @@ -457,13 +458,26 @@ def stop_tuning(step): return step +def _transformed_init(a, b): + """Transforms original starting values for transformed variables specified by + the user (dict a) and inserts them into the init dict (dict b). + """ + + for name in a: + for tname in b: + if tname.startswith(name) and tname!=name: + transform = tname.split(name)[-1][1:-1] + transform_func = distributions.__dict__[transform] + b[tname] = transform_func.forward(a[name]).eval() + + return b def _soft_update(a, b): """As opposed to dict.update, don't overwrite keys if present. """ + b = _transformed_init(a, b) a.update({k: v for k, v in b.items() if k not in a}) - def sample_ppc(trace, samples=None, model=None, vars=None, size=None, random_seed=None, progressbar=True): """Generate posterior predictive samples from a model given a trace. From c60a91429cffa4d9e8b81ea387bae3278e21d45a Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 14:07:22 -0500 Subject: [PATCH 09/18] Added test for _transformed_init in sample --- pymc3/tests/test_sampling.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 3fd5f94041..bc19d68eae 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -1,5 +1,7 @@ from itertools import combinations import numpy as np +from numpy.testing import assert_almost_equal + try: import unittest.mock as mock # py3 except ImportError: @@ -117,6 +119,12 @@ def test_soft_update_empty(self): test_point = {'a': 3, 'b': 4} pm.sampling._soft_update(start, test_point) assert start == test_point + + def test_soft_update_transformed(self): + start = {'a': 2} + test_point = {'a_log_': 0} + pm.sampling._soft_update(start, test_point) + assert assert_almost_equal(start['a_log_'], np.log(start['a'])) class TestNamedSampling(SeededTest): From 644a5b6ead7d3c18f706032aac193d6940fe388c Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 14:12:14 -0500 Subject: [PATCH 10/18] Added example in _transformed_init docstring --- pymc3/sampling.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 8cd6f87505..65ff4384f5 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -461,6 +461,15 @@ def stop_tuning(step): def _transformed_init(a, b): """Transforms original starting values for transformed variables specified by the user (dict a) and inserts them into the init dict (dict b). + + Examples + -------- + .. code:: ipython + + >>> a = {'alpha': 5.0} + >>> b = {'alpha_log_': 0} + >>> _transformed_init(a, b) + {'alpha_log_': array(1.6094379425048828, dtype=float32)} """ for name in a: From 5345811b9a6039d3448511f5433e2fcd48d7ff44 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 15:58:55 -0500 Subject: [PATCH 11/18] Added interval to __all__ in transforms --- pymc3/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 7b9e49f90b..8be405f7fb 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -6,7 +6,7 @@ from ..math import logit, invlogit import numpy as np -__all__ = ['transform', 'stick_breaking', 'logodds', +__all__ = ['transform', 'stick_breaking', 'logodds', 'interval', 'log', 'sum_to_1', 't_stick_breaking'] From 4fd90295cfd821dc5b92be4aac81de2d0af2ce41 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 16:04:06 -0500 Subject: [PATCH 12/18] Fixed namespace error --- pymc3/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 65ff4384f5..13ed2090b3 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -476,7 +476,7 @@ def _transformed_init(a, b): for tname in b: if tname.startswith(name) and tname!=name: transform = tname.split(name)[-1][1:-1] - transform_func = distributions.__dict__[transform] + transform_func = distributions.transforms.__dict__[transform] b[tname] = transform_func.forward(a[name]).eval() return b From ecf980f81704743837bf62eda6cf2a41b25c9a98 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Sun, 16 Apr 2017 17:02:18 -0500 Subject: [PATCH 13/18] Encoded bound into interval transforms so that starting values can be properly transformed --- pymc3/distributions/transforms.py | 2 +- pymc3/model.py | 8 ++++++++ pymc3/sampling.py | 26 ++++++++++++++++++++------ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index 8be405f7fb..80175d8b5d 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -7,7 +7,7 @@ import numpy as np __all__ = ['transform', 'stick_breaking', 'logodds', 'interval', - 'log', 'sum_to_1', 't_stick_breaking'] + 'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking'] class Transform(object): diff --git a/pymc3/model.py b/pymc3/model.py index c7b46e89d5..7abb2fba39 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -994,6 +994,14 @@ def __init__(self, type=None, owner=None, index=None, name=None, self.model = model transformed_name = "{}_{}_".format(name, transform.name) + bound_dict = {'a':None, 'b':None} + for key in bound_dict: + try: + bound_dict[key] = transform.__dict__[key].eval() + except KeyError: + pass + if bound_dict['a'] is not None or bound_dict['b'] is not None: + transformed_name += '({},{})_'.format(bound_dict['a'], bound_dict['b']) self.transformed = model.Var( transformed_name, transform.apply(distribution), total_size=total_size) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 13ed2090b3..e61bc4b5e7 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -458,9 +458,9 @@ def stop_tuning(step): return step -def _transformed_init(a, b): +def _transformed_init(start_vals, other_start_vals): """Transforms original starting values for transformed variables specified by - the user (dict a) and inserts them into the init dict (dict b). + the user (dict1) and inserts them into the init dict (dict2). Examples -------- @@ -472,14 +472,28 @@ def _transformed_init(a, b): {'alpha_log_': array(1.6094379425048828, dtype=float32)} """ - for name in a: - for tname in b: + for name in start_vals: + for tname in other_start_vals: if tname.startswith(name) and tname!=name: transform = tname.split(name)[-1][1:-1] + bound_dict = {} + # Check if bounds are encoded in name + if transform.find('(') > -1: + transform, interval = transform.split('_(') + a, b = interval[:-1].split(',') + if a: + bound_dict['a'] = float(a) + if b: + bound_dict['b'] = float(b) transform_func = distributions.transforms.__dict__[transform] - b[tname] = transform_func.forward(a[name]).eval() + try: + # Some transformations need to be instantiated + transform_func = transform_func(**bound_dict) + except TypeError: + pass + other_start_vals[tname] = transform_func.forward(start_vals[name]).eval() - return b + return other_start_vals def _soft_update(a, b): """As opposed to dict.update, don't overwrite keys if present. From c54f975e473233e7d64a3c90f68f29456bd4069a Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Mon, 17 Apr 2017 09:22:35 -0500 Subject: [PATCH 14/18] Simplified transformation of start values by carrying model forward into _soft_update --- pymc3/model.py | 11 +++------- pymc3/sampling.py | 53 ++++++++++------------------------------------- 2 files changed, 14 insertions(+), 50 deletions(-) diff --git a/pymc3/model.py b/pymc3/model.py index 7abb2fba39..77635278b8 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -989,19 +989,14 @@ def __init__(self, type=None, owner=None, index=None, name=None, if type is None: type = distribution.type super(TransformedRV, self).__init__(type, owner, index, name) + + self.transformation = transform if distribution is not None: self.model = model transformed_name = "{}_{}_".format(name, transform.name) - bound_dict = {'a':None, 'b':None} - for key in bound_dict: - try: - bound_dict[key] = transform.__dict__[key].eval() - except KeyError: - pass - if bound_dict['a'] is not None or bound_dict['b'] is not None: - transformed_name += '({},{})_'.format(bound_dict['a'], bound_dict['b']) + self.transformed = model.Var( transformed_name, transform.apply(distribution), total_size=total_size) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index e61bc4b5e7..81f2ce047c 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -12,7 +12,6 @@ BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep) from .plots.traceplot import traceplot -import pymc3.distributions as distributions from tqdm import tqdm import sys @@ -349,9 +348,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, strace = _choose_backend(trace, chain, model=model) if len(strace) > 0: - _soft_update(start, strace.point(-1)) + _soft_update(start, strace.point(-1), model) else: - _soft_update(start, model.test_point) + _soft_update(start, model.test_point, model) try: step = CompoundStep(step) @@ -458,47 +457,17 @@ def stop_tuning(step): return step -def _transformed_init(start_vals, other_start_vals): - """Transforms original starting values for transformed variables specified by - the user (dict1) and inserts them into the init dict (dict2). - - Examples - -------- - .. code:: ipython - - >>> a = {'alpha': 5.0} - >>> b = {'alpha_log_': 0} - >>> _transformed_init(a, b) - {'alpha_log_': array(1.6094379425048828, dtype=float32)} +def _soft_update(a, b, model): + """Update a with b, without overwriting existing keys. Values specified for + transformed variables on the original scale are also transformed and inserted. """ - - for name in start_vals: - for tname in other_start_vals: + + for name in a: + for tname in b: if tname.startswith(name) and tname!=name: - transform = tname.split(name)[-1][1:-1] - bound_dict = {} - # Check if bounds are encoded in name - if transform.find('(') > -1: - transform, interval = transform.split('_(') - a, b = interval[:-1].split(',') - if a: - bound_dict['a'] = float(a) - if b: - bound_dict['b'] = float(b) - transform_func = distributions.transforms.__dict__[transform] - try: - # Some transformations need to be instantiated - transform_func = transform_func(**bound_dict) - except TypeError: - pass - other_start_vals[tname] = transform_func.forward(start_vals[name]).eval() - - return other_start_vals - -def _soft_update(a, b): - """As opposed to dict.update, don't overwrite keys if present. - """ - b = _transformed_init(a, b) + transform_func = [d.transformation for d in model.deterministics if d.name==name][0] + b[tname] = transform_func.forward(a[name]).eval() + a.update({k: v for k, v in b.items() if k not in a}) def sample_ppc(trace, samples=None, model=None, vars=None, size=None, From b9436f09a71a35a2a6aac585ccafaa20a07f3bb2 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Wed, 19 Apr 2017 17:27:07 -0500 Subject: [PATCH 15/18] Attempt to fix spurious 2.7 diagnostic test error --- pymc3/tests/test_diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_diagnostics.py b/pymc3/tests/test_diagnostics.py index dba66c5340..33f243e902 100644 --- a/pymc3/tests/test_diagnostics.py +++ b/pymc3/tests/test_diagnostics.py @@ -21,7 +21,7 @@ def get_ptrace(self, n_samples): # Run sampler step1 = Slice([model.early_mean_log_, model.late_mean_log_]) step2 = Metropolis([model.switchpoint]) - start = {'early_mean': 7., 'late_mean': 1., 'switchpoint': 100} + start = {'early_mean': 3., 'late_mean': 1., 'switchpoint': 100} ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2, progressbar=False, random_seed=[20090425, 19700903]) return ptrace From 863079ee4d594d8cb28055ef3166b87c53c168c8 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Wed, 19 Apr 2017 17:30:52 -0500 Subject: [PATCH 16/18] Better fix for failing test --- pymc3/tests/test_diagnostics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc3/tests/test_diagnostics.py b/pymc3/tests/test_diagnostics.py index 33f243e902..78c3e6dca6 100644 --- a/pymc3/tests/test_diagnostics.py +++ b/pymc3/tests/test_diagnostics.py @@ -21,7 +21,7 @@ def get_ptrace(self, n_samples): # Run sampler step1 = Slice([model.early_mean_log_, model.late_mean_log_]) step2 = Metropolis([model.switchpoint]) - start = {'early_mean': 3., 'late_mean': 1., 'switchpoint': 100} + start = {'early_mean': 7., 'late_mean': 1., 'switchpoint': 100} ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2, progressbar=False, random_seed=[20090425, 19700903]) return ptrace @@ -35,7 +35,7 @@ def test_good(self): def test_bad(self): """Confirm Gelman-Rubin statistic is far from 1 for a small number of samples.""" - n_samples = 10 + n_samples = 5 rhat = gelman_rubin(self.get_ptrace(n_samples)) assert not all(1 / self.good_ratio < r < self.good_ratio for r in rhat.values()) @@ -43,7 +43,7 @@ def test_bad(self): def test_right_shape_python_float(self, shape=None, test_shape=None): """Check Gelman-Rubin statistic shape is correct w/ python float""" n_jobs = 3 - n_samples = 10 + n_samples = 5 with Model(): if shape is not None: From b1866e984912d727cbb9f682254911f2fba00b9a Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Wed, 19 Apr 2017 20:34:35 -0500 Subject: [PATCH 17/18] Tweaks to fix stochastic diagnostic test error --- pymc3/tests/test_diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_diagnostics.py b/pymc3/tests/test_diagnostics.py index 78c3e6dca6..9f4dd578d2 100644 --- a/pymc3/tests/test_diagnostics.py +++ b/pymc3/tests/test_diagnostics.py @@ -21,7 +21,7 @@ def get_ptrace(self, n_samples): # Run sampler step1 = Slice([model.early_mean_log_, model.late_mean_log_]) step2 = Metropolis([model.switchpoint]) - start = {'early_mean': 7., 'late_mean': 1., 'switchpoint': 100} + start = {'early_mean': 7., 'late_mean': 5., 'switchpoint': 10} ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2, progressbar=False, random_seed=[20090425, 19700903]) return ptrace From 6117810cbcc91e040fd67160e2c1719f0b038db1 Mon Sep 17 00:00:00 2001 From: Christopher Fonnesbeck Date: Thu, 20 Apr 2017 07:33:21 -0500 Subject: [PATCH 18/18] Renamed _soft_update to _updates_start_vals --- pymc3/sampling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 81f2ce047c..ca93b0545c 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -348,9 +348,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, strace = _choose_backend(trace, chain, model=model) if len(strace) > 0: - _soft_update(start, strace.point(-1), model) + _update_start_vals(start, strace.point(-1), model) else: - _soft_update(start, model.test_point, model) + _update_start_vals(start, model.test_point, model) try: step = CompoundStep(step) @@ -457,7 +457,7 @@ def stop_tuning(step): return step -def _soft_update(a, b, model): +def _update_start_vals(a, b, model): """Update a with b, without overwriting existing keys. Values specified for transformed variables on the original scale are also transformed and inserted. """