Skip to content

Use consistent transform names everywhere #2089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,55 @@
FlatView = collections.namedtuple('FlatView', 'input, replacements, view')


def get_transformed_name(name, transform):
    """
    Build the name used for the transformed version of a variable.

    The result is ``name``, an underscore, the transform's ``name``
    attribute and a trailing double underscore.

    Parameters
    ----------
    name : str
        Name of the untransformed variable
    transform : object
        Object exposing a ``name`` attribute; should be a subclass of
        `transforms.Transform`

    Returns
    -------
    A string to use for the transformed variable
    """
    transform_label = transform.name
    return "{}_{}__".format(name, transform_label)


def is_transformed_name(name):
    """
    Cheap check for whether a name looks like output of `get_transformed_name`.

    A name passes when it ends with a double underscore and contains at
    least three underscores in total.

    Parameters
    ----------
    name : str
        Name to check

    Returns
    -------
    Boolean, whether the string could have been produced by
    `get_transformed_name`
    """
    # Without the trailing double underscore it cannot be a transformed name.
    if not name.endswith('__'):
        return False
    # One separator underscore plus the two trailing ones gives at least three.
    return name.count('_') >= 3


def get_untransformed_name(name):
    """
    Recover the original name from one built by `get_transformed_name`.

    Parameters
    ----------
    name : str
        Name to untransform

    Returns
    -------
    String with untransformed version of the name.

    Raises
    ------
    ValueError
        If `is_transformed_name` rejects ``name``.
    """
    if is_transformed_name(name):
        pieces = name.split('_')
        # Drop the last three underscore-delimited pieces; for a name such as
        # ``var_test__`` these are the transform name and two empty strings.
        return '_'.join(pieces[:-3])
    raise ValueError(u'{} does not appear to be a transformed name'.format(name))


class InstanceMethod(object):
"""Class for hiding references to instance methods so they can be pickled.

Expand Down Expand Up @@ -516,7 +565,7 @@ def Var(self, name, dist, data=None, total_size=None):
' and added transformed {orig_name} to model.'.format(
transform=dist.transform.name,
name=name,
orig_name='{}_{}_'.format(name, dist.transform.name)))
orig_name=get_transformed_name(name, dist.transform)))
self.deterministics.append(var)
return var
elif isinstance(data, dict):
Expand Down Expand Up @@ -989,13 +1038,13 @@ def __init__(self, type=None, owner=None, index=None, name=None,
if type is None:
type = distribution.type
super(TransformedRV, self).__init__(type, owner, index, name)

self.transformation = transform

if distribution is not None:
self.model = model

transformed_name = "{}_{}_".format(name, transform.name)
transformed_name = get_transformed_name(name, transform)

self.transformed = model.Var(
transformed_name, transform.apply(distribution), total_size=total_size)
Expand Down
9 changes: 4 additions & 5 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pymc3 as pm
from .backends.base import merge_traces, BaseTrace, MultiTrace
from .backends.ndarray import NDArray
from .model import modelcontext, Point
from .model import modelcontext, Point, is_transformed_name, get_untransformed_name
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
Slice, CompoundStep)
Expand Down Expand Up @@ -462,14 +462,13 @@ def _update_start_vals(a, b, model):
"""Update a with b, without overwriting existing keys. Values specified for
transformed variables on the original scale are also transformed and inserted.
"""

for name in a:
for tname in b:
if tname.startswith(name) and tname!=name:
transform_func = [d.transformation for d in model.deterministics if d.name==name]
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
transform_func = [d.transformation for d in model.deterministics if d.name == name]
if transform_func:
b[tname] = transform_func[0].forward(a[name]).eval()

a.update({k: v for k, v in b.items() if k not in a})

def sample_ppc(trace, samples=None, model=None, vars=None, size=None,
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_advi.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,4 @@ def test_sample_vp(self):
trace = sample_vp(v_params, draws=1, hide_transformed=True)
assert trace.varnames == ['p']
trace = sample_vp(v_params, draws=1, hide_transformed=False)
assert sorted(trace.varnames) == ['p', 'p_logodds_']
assert sorted(trace.varnames) == ['p', 'p_logodds__']
6 changes: 3 additions & 3 deletions pymc3/tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def get_ptrace(self, n_samples):
model = build_disaster_model()
with model:
# Run sampler
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
step1 = Slice([model.early_mean_log__, model.late_mean_log__])
step2 = Metropolis([model.switchpoint])
start = {'early_mean': 7., 'late_mean': 5., 'switchpoint': 10}
ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2,
ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2,
progressbar=False, random_seed=[20090425, 19700903])
return ptrace

Expand Down Expand Up @@ -91,7 +91,7 @@ def get_switchpoint(self, n_samples):
model = build_disaster_model()
with model:
# Run sampler
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
step1 = Slice([model.early_mean_log__, model.late_mean_log__])
step2 = Metropolis([model.switchpoint])
trace = sample(n_samples, step=[step1, step2], progressbar=False, random_seed=1)
return trace['switchpoint']
Expand Down
4 changes: 2 additions & 2 deletions pymc3/tests/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ def test_multinomial_bound():
p_b = pm.Dirichlet('p', floatX(np.ones(2)))
MultinomialB('x', n, p_b, observed=x)

assert np.isclose(modelA.logp({'p_stickbreaking_': [0]}),
modelB.logp({'p_stickbreaking_': [0]}))
assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),
modelB.logp({'p_stickbreaking__': [0]}))
16 changes: 8 additions & 8 deletions pymc3/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ def build_model(self):
def too_slow(self):
model = self.build_model()
start = {'groupmean': self.obs_means.mean(),
'groupsd_interval_': 0,
'sd_interval_': 0,
'groupsd_interval__': 0,
'sd_interval__': 0,
'means': self.obs_means,
'floor_m': 0.,
}
with model:
start = pm.find_MAP(start=start,
vars=[model['groupmean'], model['sd_interval_'], model['floor_m']])
vars=[model['groupmean'], model['sd_interval__'], model['floor_m']])
step = pm.NUTS(model.vars, scaling=start)
pm.sample(50, step=step, start=start)

Expand Down Expand Up @@ -117,8 +117,8 @@ def too_slow(self):
with model:
start = pm.Point({
'groupmean': self.obs_means.mean(),
'groupsd_interval_': 0,
'sd_interval_': 0,
'groupsd_interval__': 0,
'sd_interval__': 0,
'means': np.array(self.obs_means),
'u_m': np.array([.72]),
'floor_m': 0.,
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_disaster_model(self):
# Initial values for stochastic nodes
start = {'early_mean': 2., 'late_mean': 3.}
# Use slice sampler for means (other variables auto-selected)
step = pm.Slice([model.early_mean_log_, model.late_mean_log_])
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
tr = pm.sample(500, tune=50, start=start, step=step)
pm.summary(tr)

Expand All @@ -178,7 +178,7 @@ def test_disaster_model_missing(self):
# Initial values for stochastic nodes
start = {'early_mean': 2., 'late_mean': 3.}
# Use slice sampler for means (other variables auto-selected)
step = pm.Slice([model.early_mean_log_, model.late_mean_log_])
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
tr = pm.sample(500, tune=50, start=start, step=step)
pm.summary(tr)

Expand Down Expand Up @@ -260,7 +260,7 @@ def test_run(self):
model = self.build_model()
with model:
start = {'psi': 0.5, 'z': (self.y > 0).astype(int), 'theta': 5}
step_one = pm.Metropolis([model.theta_interval_, model.psi_logodds_])
step_one = pm.Metropolis([model.theta_interval__, model.psi_logodds__])
step_two = pm.BinaryMetropolis([model.z])
pm.sample(50, step=[step_one, step_two], start=start)

Expand Down
26 changes: 26 additions & 0 deletions pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pymc3 as pm
from pymc3.distributions import HalfCauchy, Normal
from pymc3.distributions.transforms import Transform
from pymc3 import Potential, Deterministic
from pymc3.theanof import generator
from .helpers import select_by_precision
Expand Down Expand Up @@ -193,3 +194,28 @@ def test_gradient_with_scaling(self):
g1 = grad1(1)
g2 = grad2(1)
np.testing.assert_almost_equal(g1, g2)


class TestTransformName(object):
    """Tests for the transformed-name helpers exposed on ``pm.model``."""

    # Pairs of (untransformed name, name expected after applying a transform
    # called ``transform_name``).
    cases = [
        ('var', 'var_test__'),
        ('var_test_', 'var_test__test__')
    ]
    transform_name = 'test'

    def test_get_transformed_name(self):
        """Each untransformed name maps to its expected transformed name."""
        dummy_transform = Transform()
        dummy_transform.name = self.transform_name
        for original, expected in self.cases:
            produced = pm.model.get_transformed_name(original, dummy_transform)
            assert produced == expected

    def test_is_transformed_name(self):
        """Transformed names are recognised; untransformed ones are not."""
        for original, expected in self.cases:
            assert pm.model.is_transformed_name(expected)
            assert not pm.model.is_transformed_name(original)

    def test_get_untransformed_name(self):
        """Untransforming round-trips, and rejects untransformed names."""
        for original, expected in self.cases:
            recovered = pm.model.get_untransformed_name(expected)
            assert recovered == original
            with pytest.raises(ValueError):
                pm.model.get_untransformed_name(original)
6 changes: 3 additions & 3 deletions pymc3/tests/test_models_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup_class(cls):

def test_linear_component(self):
vars_to_create = {
'sigma_interval_',
'sigma_interval__',
'y_obs',
'lm_x0',
'lm_Intercept'
Expand All @@ -41,7 +41,7 @@ def test_linear_component(self):
self.data_linear['y'],
name='lm'
) # yields lm_x0, lm_Intercept
sigma = Uniform('sigma', 0, 20) # yields sigma_interval_
sigma = Uniform('sigma', 0, 20) # yields sigma_interval__
Normal('y_obs', mu=lm.y_est, sd=sigma, observed=self.y_linear) # yields y_obs
start = find_MAP(vars=[sigma])
step = Slice(model.vars)
Expand All @@ -68,7 +68,7 @@ def test_linear_component_from_formula(self):
def test_glm(self):
with Model() as model:
vars_to_create = {
'glm_sd_log_',
'glm_sd_log__',
'glm_y',
'glm_x0',
'glm_Intercept'
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_multichain_plots():
model = build_disaster_model()
with model:
# Run sampler
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
step1 = Slice([model.early_mean_log__, model.late_mean_log__])
step2 = Metropolis([model.switchpoint])
start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50}
ptrace = sample(1000, step=[step1, step2], start=start, njobs=2)
Expand Down
6 changes: 3 additions & 3 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ def test_soft_update_empty(self):
test_point = {'a': 3, 'b': 4}
pm.sampling._soft_update(start, test_point)
assert start == test_point

def test_soft_update_transformed(self):
start = {'a': 2}
test_point = {'a_log_': 0}
test_point = {'a_log__': 0}
pm.sampling._soft_update(start, test_point)
assert assert_almost_equal(start['a_log_'], np.log(start['a']))
assert assert_almost_equal(start['a_log__'], np.log(start['a']))


class TestNamedSampling(SeededTest):
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,5 +378,5 @@ def test_row_names(self):
step = Metropolis()
trace = pm.sample(100, step=step)
ds = df_summary(trace, batches=3, include_transformed=True)
npt.assert_equal(np.array(['x_interval_', 'x']),
npt.assert_equal(np.array(['x_interval__', 'x']),
ds.index)
2 changes: 1 addition & 1 deletion pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_sample(self):
assert trace.varnames == ['p']
assert len(trace) == 1
trace = app.sample(draws=10, hide_transformed=False)
assert sorted(trace.varnames) == ['p', 'p_logodds_']
assert sorted(trace.varnames) == ['p', 'p_logodds__']
assert len(trace) == 10

def test_sample_node(self):
Expand Down