Skip to content

Commit b53a746

Browse files
committed
Use consistent transform names everywhere
1 parent 7c5ca8d commit b53a746

12 files changed

+105
-31
lines changed

pymc3/model.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,55 @@
2323
FlatView = collections.namedtuple('FlatView', 'input, replacements, view')
2424

2525

26+
def get_transformed_name(name, transform):
    """Build the canonical name of a transformed variable.

    The result is ``name``, an underscore, the transform's ``name``
    attribute and a trailing double underscore; for example ``'p'`` with a
    transform named ``'logodds'`` gives ``'p_logodds__'``.

    Parameters
    ----------
    name : str
        Name of the untransformed variable.
    transform : object
        Object exposing a ``name`` attribute; should be a subclass of
        `transforms.Transform`.

    Returns
    -------
    str
        The name to use for the transformed variable.
    """
    transform_label = transform.name
    return "{}_{}__".format(name, transform_label)
41+
42+
43+
def is_transformed_name(name):
    """Cheaply test whether ``name`` looks like a transformed variable name.

    The check is purely textual: the name must end in a double underscore
    and contain at least three underscores in total, which is the shape of
    the names built by `get_transformed_name`.

    Parameters
    ----------
    name : str
        Name to check.

    Returns
    -------
    bool
        Whether the string could have been produced by
        `get_transformed_name`.
    """
    # Guard clause: anything without the trailing '__' is rejected outright.
    if not name.endswith('__'):
        return False
    return name.count('_') >= 3
56+
57+
58+
def get_untransformed_name(name):
    """Recover the original variable name from a transformed one.

    Undoes `get_transformed_name` by cutting the string at the third
    underscore from the right, which removes the transform name and the
    trailing ``__``. Because exactly three underscore-separated pieces are
    dropped, a transform name that itself contains an underscore is only
    partly stripped.

    Parameters
    ----------
    name : str
        Name to untransform.

    Returns
    -------
    str
        The untransformed version of the name.

    Raises
    ------
    ValueError
        If `is_transformed_name` reports that ``name`` was not transformed.
    """
    if is_transformed_name(name):
        # Keep everything to the left of the third underscore from the right.
        return name.rsplit('_', 3)[0]
    raise ValueError(u'{} does not appear to be a transformed name'.format(name))
73+
74+
2675
class InstanceMethod(object):
2776
"""Class for hiding references to instance methods so they can be pickled.
2877
@@ -516,7 +565,7 @@ def Var(self, name, dist, data=None, total_size=None):
516565
' and added transformed {orig_name} to model.'.format(
517566
transform=dist.transform.name,
518567
name=name,
519-
orig_name='{}_{}_'.format(name, dist.transform.name)))
568+
orig_name=get_transformed_name(name, dist.transform)))
520569
self.deterministics.append(var)
521570
return var
522571
elif isinstance(data, dict):
@@ -989,13 +1038,13 @@ def __init__(self, type=None, owner=None, index=None, name=None,
9891038
if type is None:
9901039
type = distribution.type
9911040
super(TransformedRV, self).__init__(type, owner, index, name)
992-
1041+
9931042
self.transformation = transform
9941043

9951044
if distribution is not None:
9961045
self.model = model
9971046

998-
transformed_name = "{}_{}_".format(name, transform.name)
1047+
transformed_name = get_transformed_name(name, transform)
9991048

10001049
self.transformed = model.Var(
10011050
transformed_name, transform.apply(distribution), total_size=total_size)

pymc3/sampling.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pymc3 as pm
88
from .backends.base import merge_traces, BaseTrace, MultiTrace
99
from .backends.ndarray import NDArray
10-
from .model import modelcontext, Point
10+
from .model import modelcontext, Point, is_transformed_name, get_untransformed_name
1111
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
1212
BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
1313
Slice, CompoundStep)
@@ -462,14 +462,13 @@ def _update_start_vals(a, b, model):
462462
"""Update a with b, without overwriting existing keys. Values specified for
463463
transformed variables on the original scale are also transformed and inserted.
464464
"""
465-
466465
for name in a:
467466
for tname in b:
468-
if tname.startswith(name) and tname!=name:
469-
transform_func = [d.transformation for d in model.deterministics if d.name==name]
467+
if is_transformed_name(tname) and get_untransformed_name(tname) == name:
468+
transform_func = [d.transformation for d in model.deterministics if d.name == name]
470469
if transform_func:
471470
b[tname] = transform_func[0].forward(a[name]).eval()
472-
471+
473472
a.update({k: v for k, v in b.items() if k not in a})
474473

475474
def sample_ppc(trace, samples=None, model=None, vars=None, size=None,

pymc3/tests/test_advi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,4 +264,4 @@ def test_sample_vp(self):
264264
trace = sample_vp(v_params, draws=1, hide_transformed=True)
265265
assert trace.varnames == ['p']
266266
trace = sample_vp(v_params, draws=1, hide_transformed=False)
267-
assert sorted(trace.varnames) == ['p', 'p_logodds_']
267+
assert sorted(trace.varnames) == ['p', 'p_logodds__']

pymc3/tests/test_diagnostics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def get_ptrace(self, n_samples):
1919
model = build_disaster_model()
2020
with model:
2121
# Run sampler
22-
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
22+
step1 = Slice([model.early_mean_log__, model.late_mean_log__])
2323
step2 = Metropolis([model.switchpoint])
2424
start = {'early_mean': 7., 'late_mean': 5., 'switchpoint': 10}
25-
ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2,
25+
ptrace = sample(n_samples, step=[step1, step2], start=start, njobs=2,
2626
progressbar=False, random_seed=[20090425, 19700903])
2727
return ptrace
2828

@@ -91,7 +91,7 @@ def get_switchpoint(self, n_samples):
9191
model = build_disaster_model()
9292
with model:
9393
# Run sampler
94-
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
94+
step1 = Slice([model.early_mean_log__, model.late_mean_log__])
9595
step2 = Metropolis([model.switchpoint])
9696
trace = sample(n_samples, step=[step1, step2], progressbar=False, random_seed=1)
9797
return trace['switchpoint']

pymc3/tests/test_dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,5 @@ def test_multinomial_bound():
111111
p_b = pm.Dirichlet('p', floatX(np.ones(2)))
112112
MultinomialB('x', n, p_b, observed=x)
113113

114-
assert np.isclose(modelA.logp({'p_stickbreaking_': [0]}),
115-
modelB.logp({'p_stickbreaking_': [0]}))
114+
assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),
115+
modelB.logp({'p_stickbreaking__': [0]}))

pymc3/tests/test_examples.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ def build_model(self):
7979
def too_slow(self):
8080
model = self.build_model()
8181
start = {'groupmean': self.obs_means.mean(),
82-
'groupsd_interval_': 0,
83-
'sd_interval_': 0,
82+
'groupsd_interval__': 0,
83+
'sd_interval__': 0,
8484
'means': self.obs_means,
8585
'floor_m': 0.,
8686
}
8787
with model:
8888
start = pm.find_MAP(start=start,
89-
vars=[model['groupmean'], model['sd_interval_'], model['floor_m']])
89+
vars=[model['groupmean'], model['sd_interval__'], model['floor_m']])
9090
step = pm.NUTS(model.vars, scaling=start)
9191
pm.sample(50, step=step, start=start)
9292

@@ -117,8 +117,8 @@ def too_slow(self):
117117
with model:
118118
start = pm.Point({
119119
'groupmean': self.obs_means.mean(),
120-
'groupsd_interval_': 0,
121-
'sd_interval_': 0,
120+
'groupsd_interval__': 0,
121+
'sd_interval__': 0,
122122
'means': np.array(self.obs_means),
123123
'u_m': np.array([.72]),
124124
'floor_m': 0.,
@@ -168,7 +168,7 @@ def test_disaster_model(self):
168168
# Initial values for stochastic nodes
169169
start = {'early_mean': 2., 'late_mean': 3.}
170170
# Use slice sampler for means (other variables auto-selected)
171-
step = pm.Slice([model.early_mean_log_, model.late_mean_log_])
171+
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
172172
tr = pm.sample(500, tune=50, start=start, step=step)
173173
pm.summary(tr)
174174

@@ -178,7 +178,7 @@ def test_disaster_model_missing(self):
178178
# Initial values for stochastic nodes
179179
start = {'early_mean': 2., 'late_mean': 3.}
180180
# Use slice sampler for means (other variables auto-selected)
181-
step = pm.Slice([model.early_mean_log_, model.late_mean_log_])
181+
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
182182
tr = pm.sample(500, tune=50, start=start, step=step)
183183
pm.summary(tr)
184184

@@ -260,7 +260,7 @@ def test_run(self):
260260
model = self.build_model()
261261
with model:
262262
start = {'psi': 0.5, 'z': (self.y > 0).astype(int), 'theta': 5}
263-
step_one = pm.Metropolis([model.theta_interval_, model.psi_logodds_])
263+
step_one = pm.Metropolis([model.theta_interval__, model.psi_logodds__])
264264
step_two = pm.BinaryMetropolis([model.z])
265265
pm.sample(50, step=[step_one, step_two], start=start)
266266

pymc3/tests/test_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pymc3 as pm
55
from pymc3.distributions import HalfCauchy, Normal
6+
from pymc3.distributions.transforms import Transform
67
from pymc3 import Potential, Deterministic
78
from pymc3.theanof import generator
89
from .helpers import select_by_precision
@@ -193,3 +194,28 @@ def test_gradient_with_scaling(self):
193194
g1 = grad1(1)
194195
g2 = grad2(1)
195196
np.testing.assert_almost_equal(g1, g2)
197+
198+
199+
class TestTransformName(object):
    """Checks for the transformed-name helpers exposed on ``pm.model``."""

    # (untransformed name, expected transformed name) pairs. The second pair
    # starts from a name that already contains underscores.
    cases = [
        ('var', 'var_test__'),
        ('var_test_', 'var_test__test__')
    ]
    # Value assigned to the ``name`` attribute of the transform under test.
    transform_name = 'test'

    def test_get_transformed_name(self):
        """Every plain name maps to its expected transformed name."""
        transform = Transform()
        transform.name = self.transform_name
        for plain, expected in self.cases:
            actual = pm.model.get_transformed_name(plain, transform)
            assert actual == expected

    def test_is_transformed_name(self):
        """Transformed names are recognised and plain names are rejected."""
        for plain, expected in self.cases:
            assert pm.model.is_transformed_name(expected)
            assert not pm.model.is_transformed_name(plain)

    def test_get_untransformed_name(self):
        """Untransforming round-trips, and plain names raise ValueError."""
        for plain, expected in self.cases:
            recovered = pm.model.get_untransformed_name(expected)
            assert recovered == plain
            with pytest.raises(ValueError):
                pm.model.get_untransformed_name(plain)

pymc3/tests/test_models_linear.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def setup_class(cls):
3030

3131
def test_linear_component(self):
3232
vars_to_create = {
33-
'sigma_interval_',
33+
'sigma_interval__',
3434
'y_obs',
3535
'lm_x0',
3636
'lm_Intercept'
@@ -41,7 +41,7 @@ def test_linear_component(self):
4141
self.data_linear['y'],
4242
name='lm'
4343
) # yields lm_x0, lm_Intercept
44-
sigma = Uniform('sigma', 0, 20) # yields sigma_interval_
44+
sigma = Uniform('sigma', 0, 20) # yields sigma_interval__
4545
Normal('y_obs', mu=lm.y_est, sd=sigma, observed=self.y_linear) # yields y_obs
4646
start = find_MAP(vars=[sigma])
4747
step = Slice(model.vars)
@@ -68,7 +68,7 @@ def test_linear_component_from_formula(self):
6868
def test_glm(self):
6969
with Model() as model:
7070
vars_to_create = {
71-
'glm_sd_log_',
71+
'glm_sd_log__',
7272
'glm_y',
7373
'glm_x0',
7474
'glm_Intercept'

pymc3/tests/test_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_multichain_plots():
5757
model = build_disaster_model()
5858
with model:
5959
# Run sampler
60-
step1 = Slice([model.early_mean_log_, model.late_mean_log_])
60+
step1 = Slice([model.early_mean_log__, model.late_mean_log__])
6161
step2 = Metropolis([model.switchpoint])
6262
start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50}
6363
ptrace = sample(1000, step=[step1, step2], start=start, njobs=2)

pymc3/tests/test_sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ def test_soft_update_empty(self):
119119
test_point = {'a': 3, 'b': 4}
120120
pm.sampling._soft_update(start, test_point)
121121
assert start == test_point
122-
122+
123123
def test_soft_update_transformed(self):
    """Check that a soft update fills in the transformed start value.

    ``start`` holds ``a`` on the original scale and ``test_point`` holds the
    transformed key ``a_log__``. After the update, ``start['a_log__']`` is
    expected to equal ``log(start['a'])``.
    """
    start = {'a': 2}
    test_point = {'a_log__': 0}
    pm.sampling._soft_update(start, test_point)
    # NOTE(review): presumably numpy.testing.assert_almost_equal, which raises
    # on mismatch and returns None. Wrapping it in ``assert`` would therefore
    # fail even when the values agree, so it is called directly.
    assert_almost_equal(start['a_log__'], np.log(start['a']))
128128

129129

130130
class TestNamedSampling(SeededTest):

pymc3/tests/test_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,5 +378,5 @@ def test_row_names(self):
378378
step = Metropolis()
379379
trace = pm.sample(100, step=step)
380380
ds = df_summary(trace, batches=3, include_transformed=True)
381-
npt.assert_equal(np.array(['x_interval_', 'x']),
381+
npt.assert_equal(np.array(['x_interval__', 'x']),
382382
ds.index)

pymc3/tests/test_variational_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def test_sample(self):
116116
assert trace.varnames == ['p']
117117
assert len(trace) == 1
118118
trace = app.sample(draws=10, hide_transformed=False)
119-
assert sorted(trace.varnames) == ['p', 'p_logodds_']
119+
assert sorted(trace.varnames) == ['p', 'p_logodds__']
120120
assert len(trace) == 10
121121

122122
def test_sample_node(self):

0 commit comments

Comments
 (0)