Skip to content

Commit bfe63c2

Browse files
denadai2twiecki
authored andcommitted
Fixed inconsistency of advi and opvi include_transformed param (#2117)
* Fixed inconsistency of advi and opvi include_transformed param * Fixed other inconsistencies on the ADVI OPVI API
1 parent a97206d commit bfe63c2

File tree

8 files changed

+62
-97
lines changed

8 files changed

+62
-97
lines changed

docs/source/notebooks/GLM-model-selection.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@
313313
" _ = plt.legend()\n",
314314
" _ = ax1d.set_xlim(xlims)\n",
315315
" _ = sns.regplot(x='x', y='y', data=rawdata, fit_reg=False\n",
316-
" ,scatter_kws={'alpha':0.7,'s':100, 'lw':2,'edgecolor':'w'}, ax=ax1d)\n"
316+
" ,scatter_kws={'alpha':0.7,'s':100, 'lw':2,'edgecolor':'w'}, ax=ax1d)"
317317
]
318318
},
319319
{
@@ -1054,7 +1054,7 @@
10541054
}
10551055
],
10561056
"source": [
1057-
"dftrc_lin = pm.trace_to_dataframe(traces_lin['k1'], hide_transformed_vars=False)\n",
1057+
"dftrc_lin = pm.trace_to_dataframe(traces_lin['k1'], include_transformed=True)\n",
10581058
"trc_lin_logp = dftrc_lin.apply(lambda x: models_lin['k1'].logp(x.to_dict()), axis=1)\n",
10591059
"mean_deviance = -2 * trc_lin_logp.mean(0)\n",
10601060
"mean_deviance"
@@ -1406,7 +1406,7 @@
14061406
"language_info": {
14071407
"codemirror_mode": {
14081408
"name": "ipython",
1409-
"version": 3
1409+
"version": 3.0
14101410
},
14111411
"file_extension": ".py",
14121412
"mimetype": "text/x-python",
@@ -1420,14 +1420,14 @@
14201420
"87b986ac3e5a43ec859cf10e013f2955": {
14211421
"views": [
14221422
{
1423-
"cell_index": 9
1423+
"cell_index": 9.0
14241424
}
14251425
]
14261426
},
14271427
"f1f05f8da738419e8e2c54ee1809c61c": {
14281428
"views": [
14291429
{
1430-
"cell_index": 47
1430+
"cell_index": 47.0
14311431
}
14321432
]
14331433
}
@@ -1437,4 +1437,4 @@
14371437
},
14381438
"nbformat": 4,
14391439
"nbformat_minor": 0
1440-
}
1440+
}

pymc3/backends/text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,5 +202,5 @@ def dump(name, trace, chains=None):
202202
for chain in chains:
203203
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
204204
df = ttab.trace_to_dataframe(
205-
trace, chains=chain, hide_transformed_vars=False)
205+
trace, chains=chain, include_transformed=True)
206206
df.to_csv(filename, index=False)

pymc3/backends/tracetab.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__all__ = ['trace_to_dataframe']
1010

1111

12-
def trace_to_dataframe(trace, chains=None, varnames=None, hide_transformed_vars=True):
12+
def trace_to_dataframe(trace, chains=None, varnames=None, include_transformed=False):
1313
"""Convert trace to Pandas DataFrame.
1414
1515
Parameters
@@ -21,15 +21,15 @@ def trace_to_dataframe(trace, chains=None, varnames=None, hide_transformed_vars=
2121
varnames : list of variable names
2222
Variables to be included in the DataFrame, if None all variable are
2323
included.
24-
hide_transformed_vars: boolean
25-
If true transformed variables will not be included in the resulting
24+
include_transformed: boolean
25+
If true transformed variables will be included in the resulting
2626
DataFrame.
2727
"""
2828
var_shapes = trace._straces[0].var_shapes
2929

3030
if varnames is None:
3131
varnames = get_default_varnames(var_shapes.keys(),
32-
include_transformed=not hide_transformed_vars)
32+
include_transformed=include_transformed)
3333

3434
flat_names = {v: create_flat_names(v, var_shapes[v]) for v in varnames}
3535

@@ -60,7 +60,7 @@ def create_flat_names(varname, shape):
6060

6161

6262
def _create_shape(flat_names):
63-
"Determine shape from `create_flat_names` output."
63+
"""Determine shape from `create_flat_names` output."""
6464
try:
6565
_, shape_str = flat_names[-1].rsplit('__', 1)
6666
except ValueError:

pymc3/tests/test_advi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_sample_vp(self):
261261
p = pm.Beta('p', alpha=1, beta=1)
262262
pm.Binomial('xs', n=1, p=p, observed=xs)
263263
v_params = advi(n=1000)
264-
trace = sample_vp(v_params, draws=1, hide_transformed=True)
264+
trace = sample_vp(v_params, draws=1, include_transformed=False)
265265
assert trace.varnames == ['p']
266-
trace = sample_vp(v_params, draws=1, hide_transformed=False)
266+
trace = sample_vp(v_params, draws=1, include_transformed=True)
267267
assert sorted(trace.varnames) == ['p', 'p_logodds__']

pymc3/tests/test_variational_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def test_sample(self):
112112
p = pm.Beta('p', alpha=1, beta=1)
113113
pm.Binomial('xs', n=1, p=p, observed=xs)
114114
app = self.inference().approx
115-
trace = app.sample(draws=1, hide_transformed=True)
115+
trace = app.sample(draws=1, include_transformed=False)
116116
assert trace.varnames == ['p']
117117
assert len(trace) == 1
118-
trace = app.sample(draws=10, hide_transformed=False)
118+
trace = app.sample(draws=10, include_transformed=True)
119119
assert sorted(trace.varnames) == ['p', 'p_logodds__']
120120
assert len(trace) == 10
121121

pymc3/variational/advi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def optimizer(loss, param):
340340

341341
def sample_vp(
342342
vparams, draws=1000, model=None, local_RVs=None, random_seed=None,
343-
hide_transformed=True, progressbar=True):
343+
include_transformed=False, progressbar=True):
344344
"""Draw samples from variational posterior.
345345
346346
Parameters
@@ -353,8 +353,8 @@ def sample_vp(
353353
Probabilistic model.
354354
random_seed : int or None
355355
Seed of random number generator. None to use current seed.
356-
hide_transformed : bool
357-
If False, transformed variables are also sampled. Default is True.
356+
include_transformed : bool
357+
If True, transformed variables are also sampled. Default is False.
358358
359359
Returns
360360
-------
@@ -411,7 +411,7 @@ def rvs(x):
411411

412412
# Random variables which will be sampled
413413
vars_sampled = pm.util.get_default_varnames(model.unobserved_RVs,
414-
include_transformed=not hide_transformed)
414+
include_transformed=include_transformed)
415415

416416
varnames = [str(var) for var in model.unobserved_RVs]
417417
trace = pm.sampling.NDArray(model=model, vars=vars_sampled)

pymc3/variational/approximations.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818

1919
class MeanField(Approximation):
20-
"""
21-
Mean Field approximation to the posterior where spherical Gaussian family
20+
"""Mean Field approximation to the posterior where spherical Gaussian family
2221
is fitted to minimize KL divergence from True posterior. It is assumed
2322
that latent space variables are uncorrelated that is the main drawback
2423
of the method
@@ -92,8 +91,7 @@ def random_global(self, size=None, no_rand=False):
9291

9392

9493
class FullRank(Approximation):
95-
"""
96-
Full Rank approximation to the posterior where Multivariate Gaussian family
94+
"""Full Rank approximation to the posterior where Multivariate Gaussian family
9795
is fitted to minimize KL divergence from True posterior. In contrast to
9896
MeanField approach correlations between variables are taken in account. The
9997
main drawback of the method is computational cost.
@@ -175,8 +173,7 @@ def create_shared_params(self, **kwargs):
175173
}
176174

177175
def log_q_W_global(self, z):
178-
"""
179-
log_q_W samples over q for global vars
176+
"""log_q_W samples over q for global vars
180177
"""
181178
mu = self.scale_grad(self.mean)
182179
L = self.scale_grad(self.L)
@@ -197,8 +194,7 @@ def random_global(self, size=None, no_rand=False):
197194

198195
@classmethod
199196
def from_mean_field(cls, mean_field, gpu_compat=False):
200-
"""
201-
Construct FullRank from MeanField approximation
197+
"""Construct FullRank from MeanField approximation
202198
203199
Parameters
204200
----------
@@ -233,8 +229,7 @@ def from_mean_field(cls, mean_field, gpu_compat=False):
233229

234230

235231
class Empirical(Approximation):
236-
"""
237-
Builds Approximation instance from a given trace,
232+
"""Builds Approximation instance from a given trace,
238233
it has the same interface as variational approximation
239234
240235
Parameters
@@ -309,16 +304,14 @@ def random_global(self, size=None, no_rand=False):
309304

310305
@property
311306
def histogram(self):
312-
"""
313-
Shortcut to flattened Trace
307+
"""Shortcut to flattened Trace
314308
"""
315309
return self.shared_params
316310

317311
@property
318312
@memoize
319313
def histogram_logp(self):
320-
"""
321-
Symbolic logp for every point in trace
314+
"""Symbolic logp for every point in trace
322315
"""
323316
node = self.to_flat_input(self.model.logpt)
324317

@@ -341,8 +334,7 @@ def cov(self):
341334

342335
@classmethod
343336
def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, seed=None):
344-
"""
345-
Initialize Histogram with random noise
337+
"""Initialize Histogram with random noise
346338
347339
Parameters
348340
----------
@@ -371,17 +363,16 @@ def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, see
371363
return hist
372364

373365

374-
def sample_approx(approx, draws=100, hide_transformed=False):
375-
"""
376-
Draw samples from variational posterior.
366+
def sample_approx(approx, draws=100, include_transformed=True):
367+
"""Draw samples from variational posterior.
377368
378369
Parameters
379370
----------
380371
approx : Approximation
381372
draws : int
382373
Number of random samples.
383-
hide_transformed : bool
384-
If False, transformed variables are also sampled. Default is True.
374+
include_transformed : bool
375+
If True, transformed variables are also sampled. Default is True.
385376
386377
Returns
387378
-------
@@ -390,4 +381,4 @@ def sample_approx(approx, draws=100, hide_transformed=False):
390381
"""
391382
if not isinstance(approx, Approximation):
392383
raise TypeError('Need Approximation instance, got %r' % approx)
393-
return approx.sample(draws=draws, hide_transformed=hide_transformed)
384+
return approx.sample(draws=draws, include_transformed=include_transformed)

0 commit comments

Comments
 (0)