diff --git a/pymc3/model.py b/pymc3/model.py index 77635278b8..a9ca9e2620 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -23,6 +23,55 @@ FlatView = collections.namedtuple('FlatView', 'input, replacements, view') +def get_transformed_name(name, transform): + """ + Consistent way of transforming names + + Parameters + ---------- + name : str + Name to transform + transform : object + Should be a subclass of `transforms.Transform` + + Returns: + A string to use for the transformed variable + """ + return "{}_{}__".format(name, transform.name) + + +def is_transformed_name(name): + """ + Quickly check if a name was transformed with `get_transormed_name` + + Parameters + ---------- + name : str + Name to check + + Returns: + Boolean, whether the string could have been produced by `get_transormed_name` + """ + return name.endswith('__') and name.count('_') >= 3 + + +def get_untransformed_name(name): + """ + Undo transformation in `get_transformed_name`. Throws ValueError if name wasn't transformed + + Parameters + ---------- + name : str + Name to untransform + + Returns: + String with untransformed version of the name. + """ + if not is_transformed_name(name): + raise ValueError(u'{} does not appear to be a transformed name'.format(name)) + return '_'.join(name.split('_')[:-3]) + + class InstanceMethod(object): """Class for hiding references to instance methods so they can be pickled. @@ -516,7 +565,7 @@ def Var(self, name, dist, data=None, total_size=None): ' and added transformed {orig_name} to model.'.format( transform=dist.transform.name, name=name, - orig_name='{}_{}_'.format(name, dist.transform.name))) + orig_name=get_transformed_name(name, dist.transform))) self.deterministics.append(var) return var elif isinstance(data, dict): @@ -989,13 +1038,13 @@ def __init__(self, type=None, owner=None, index=None, name=None, if type is None: type = distribution.type super(TransformedRV, self).__init__(type, owner, index, name) - + self.transformation = transform if distribution is not None: self.model = model - transformed_name = "{}_{}_".format(name, transform.name) + transformed_name = get_transformed_name(name, transform) self.transformed = model.Var( transformed_name, transform.apply(distribution), total_size=total_size) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index d02467188c..8e3ffa7fab 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -7,7 +7,7 @@ import pymc3 as pm from .backends.base import merge_traces, BaseTrace, MultiTrace from .backends.ndarray import NDArray -from .model import modelcontext, Point +from .model import modelcontext, Point, is_transformed_name, get_untransformed_name from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep) @@ -462,14 +462,13 @@ def _update_start_vals(a, b, model): """Update a with b, without overwriting existing keys. Values specified for transformed variables on the original scale are also transformed and inserted. """ - for name in a: for tname in b: - if tname.startswith(name) and tname!=name: - transform_func = [d.transformation for d in model.deterministics if d.name==name] + if is_transformed_name(tname) and get_untransformed_name(tname) == name: + transform_func = [d.transformation for d in model.deterministics if d.name == name] if transform_func: b[tname] = transform_func[0].forward(a[name]).eval() - + a.update({k: v for k, v in b.items() if k not in a}) def sample_ppc(trace, samples=None, model=None, vars=None, size=None, diff --git a/pymc3/tests/test_advi.py b/pymc3/tests/test_advi.py index 08541dd37d..e4f9e63f03 100644 --- a/pymc3/tests/test_advi.py +++ b/pymc3/tests/test_advi.py @@ -264,4 +264,4 @@ def test_sample_vp(self): trace = sample_vp(v_params, draws=1, hide_transformed=True) assert trace.varnames == ['p'] trace = sample_vp(v_params, draws=1, hide_transformed=False) - assert sorted(trace.varnames) == ['p', 'p_logodds_'] + assert sorted(trace.varnames) == ['p', 'p_logodds__'] diff --git a/pymc3/tests/test_diagnostics.py b/pymc3/tests/test_diagnostics.py index 9f4dd578d2..3f62d859c6 100644 --- a/pymc3/tests/test_diagnostics.py +++ b/pymc3/tests/test_diagnostics.py @@ -19,10 +19,10 @@ def get_ptrace(self, n_samples): model = build_disaster_model() with model: # Run sampler - step1 = Slice([model.early_mean_log_, model.late_mean_log_]) + step1 = Slice([model.early_mean_log__, model.late_mean_log__]) step2 = Metropolis([model.switchpoint]) start = {'early_mean': 7., 'late_mean': 5., 'switchpoint': 10} - ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2, + ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2, progressbar=False, random_seed=[20090425, 19700903]) return ptrace @@ -91,7 +91,7 @@ def get_switchpoint(self, n_samples): model = build_disaster_model() with model: # Run sampler - step1 = Slice([model.early_mean_log_, model.late_mean_log_]) + step1 = Slice([model.early_mean_log__, model.late_mean_log__]) step2 = Metropolis([model.switchpoint]) trace = sample(n_samples, step=[step1, step2], progressbar=False, random_seed=1) return trace['switchpoint'] diff --git a/pymc3/tests/test_dist_math.py b/pymc3/tests/test_dist_math.py index 81d9a5c1c1..7dada11a8d 100644 --- a/pymc3/tests/test_dist_math.py +++ b/pymc3/tests/test_dist_math.py @@ -111,5 +111,5 @@ def test_multinomial_bound(): p_b = pm.Dirichlet('p', floatX(np.ones(2))) MultinomialB('x', n, p_b, observed=x) - assert np.isclose(modelA.logp({'p_stickbreaking_': [0]}), - modelB.logp({'p_stickbreaking_': [0]})) + assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}), + modelB.logp({'p_stickbreaking__': [0]})) diff --git a/pymc3/tests/test_examples.py b/pymc3/tests/test_examples.py index ca3d328f60..1ec316ec14 100644 --- a/pymc3/tests/test_examples.py +++ b/pymc3/tests/test_examples.py @@ -79,14 +79,14 @@ def build_model(self): def too_slow(self): model = self.build_model() start = {'groupmean': self.obs_means.mean(), - 'groupsd_interval_': 0, - 'sd_interval_': 0, + 'groupsd_interval__': 0, + 'sd_interval__': 0, 'means': self.obs_means, 'floor_m': 0., } with model: start = pm.find_MAP(start=start, - vars=[model['groupmean'], model['sd_interval_'], model['floor_m']]) + vars=[model['groupmean'], model['sd_interval__'], model['floor_m']]) step = pm.NUTS(model.vars, scaling=start) pm.sample(50, step=step, start=start) @@ -117,8 +117,8 @@ def too_slow(self): with model: start = pm.Point({ 'groupmean': self.obs_means.mean(), - 'groupsd_interval_': 0, - 'sd_interval_': 0, + 'groupsd_interval__': 0, + 'sd_interval__': 0, 'means': np.array(self.obs_means), 'u_m': np.array([.72]), 'floor_m': 0., @@ -168,7 +168,7 @@ def test_disaster_model(self): # Initial values for stochastic nodes start = {'early_mean': 2., 'late_mean': 3.} # Use slice sampler for means (other varibles auto-selected) - step = pm.Slice([model.early_mean_log_, model.late_mean_log_]) + step = pm.Slice([model.early_mean_log__, model.late_mean_log__]) tr = pm.sample(500, tune=50, start=start, step=step) pm.summary(tr) @@ -178,7 +178,7 @@ def test_disaster_model_missing(self): # Initial values for stochastic nodes start = {'early_mean': 2., 'late_mean': 3.} # Use slice sampler for means (other varibles auto-selected) - step = pm.Slice([model.early_mean_log_, model.late_mean_log_]) + step = pm.Slice([model.early_mean_log__, model.late_mean_log__]) tr = pm.sample(500, tune=50, start=start, step=step) pm.summary(tr) @@ -260,7 +260,7 @@ def test_run(self): model = self.build_model() with model: start = {'psi': 0.5, 'z': (self.y > 0).astype(int), 'theta': 5} - step_one = pm.Metropolis([model.theta_interval_, model.psi_logodds_]) + step_one = pm.Metropolis([model.theta_interval__, model.psi_logodds__]) step_two = pm.BinaryMetropolis([model.z]) pm.sample(50, step=[step_one, step_two], start=start) diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index 453463e6a3..887b498708 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -3,6 +3,7 @@ import numpy as np import pymc3 as pm from pymc3.distributions import HalfCauchy, Normal +from pymc3.distributions.transforms import Transform from pymc3 import Potential, Deterministic from pymc3.theanof import generator from .helpers import select_by_precision @@ -193,3 +194,28 @@ def test_gradient_with_scaling(self): g1 = grad1(1) g2 = grad2(1) np.testing.assert_almost_equal(g1, g2) + + +class TestTransformName(object): + cases = [ + ('var', 'var_test__'), + ('var_test_', 'var_test__test__') + ] + transform_name = 'test' + + def test_get_transformed_name(self): + test_transform = Transform() + test_transform.name = self.transform_name + for name, transformed in self.cases: + assert pm.model.get_transformed_name(name, test_transform) == transformed + + def test_is_transformed_name(self): + for name, transformed in self.cases: + assert pm.model.is_transformed_name(transformed) + assert not pm.model.is_transformed_name(name) + + def test_get_untransformed_name(self): + for name, transformed in self.cases: + assert pm.model.get_untransformed_name(transformed) == name + with pytest.raises(ValueError): + pm.model.get_untransformed_name(name) diff --git a/pymc3/tests/test_models_linear.py b/pymc3/tests/test_models_linear.py index df1958bc8b..4b89293c45 100644 --- a/pymc3/tests/test_models_linear.py +++ b/pymc3/tests/test_models_linear.py @@ -30,7 +30,7 @@ def setup_class(cls): def test_linear_component(self): vars_to_create = { - 'sigma_interval_', + 'sigma_interval__', 'y_obs', 'lm_x0', 'lm_Intercept' @@ -41,7 +41,7 @@ def test_linear_component(self): self.data_linear['y'], name='lm' ) # yields lm_x0, lm_Intercept - sigma = Uniform('sigma', 0, 20) # yields sigma_interval_ + sigma = Uniform('sigma', 0, 20) # yields sigma_interval__ Normal('y_obs', mu=lm.y_est, sd=sigma, observed=self.y_linear) # yields y_obs start = find_MAP(vars=[sigma]) step = Slice(model.vars) @@ -68,7 +68,7 @@ def test_linear_component_from_formula(self): def test_glm(self): with Model() as model: vars_to_create = { - 'glm_sd_log_', + 'glm_sd_log__', 'glm_y', 'glm_x0', 'glm_Intercept' diff --git a/pymc3/tests/test_plots.py b/pymc3/tests/test_plots.py index 16a260559c..3229b5ee9d 100644 --- a/pymc3/tests/test_plots.py +++ b/pymc3/tests/test_plots.py @@ -57,7 +57,7 @@ def test_multichain_plots(): model = build_disaster_model() with model: # Run sampler - step1 = Slice([model.early_mean_log_, model.late_mean_log_]) + step1 = Slice([model.early_mean_log__, model.late_mean_log__]) step2 = Metropolis([model.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} ptrace = sample(1000, step=[step1, step2], start=start, njobs=2) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index bc19d68eae..79005b3378 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -119,12 +119,12 @@ def test_soft_update_empty(self): test_point = {'a': 3, 'b': 4} pm.sampling._soft_update(start, test_point) assert start == test_point - + def test_soft_update_transformed(self): start = {'a': 2} - test_point = {'a_log_': 0} + test_point = {'a_log__': 0} pm.sampling._soft_update(start, test_point) - assert assert_almost_equal(start['a_log_'], np.log(start['a'])) + assert assert_almost_equal(start['a_log__'], np.log(start['a'])) class TestNamedSampling(SeededTest): diff --git a/pymc3/tests/test_stats.py b/pymc3/tests/test_stats.py index 92d66a8978..98badc85df 100644 --- a/pymc3/tests/test_stats.py +++ b/pymc3/tests/test_stats.py @@ -378,5 +378,5 @@ def test_row_names(self): step = Metropolis() trace = pm.sample(100, step=step) ds = df_summary(trace, batches=3, include_transformed=True) - npt.assert_equal(np.array(['x_interval_', 'x']), + npt.assert_equal(np.array(['x_interval__', 'x']), ds.index) diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index 2bf647e609..7170b4b777 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -116,7 +116,7 @@ def test_sample(self): assert trace.varnames == ['p'] assert len(trace) == 1 trace = app.sample(draws=10, hide_transformed=False) - assert sorted(trace.varnames) == ['p', 'p_logodds_'] + assert sorted(trace.varnames) == ['p', 'p_logodds__'] assert len(trace) == 10 def test_sample_node(self):