Skip to content

Commit d18a0aa

Browse files
author
Chris Fonnesbeck
committed
Merge pull request #959 from pymc-devs/vars_varnames
Consistent use of vars and varnames
2 parents 4c002fd + 4f3b8db commit d18a0aa

File tree

5 files changed

+27
-27
lines changed

5 files changed

+27
-27
lines changed

pymc3/examples/GHME_2013.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def run(n=3000):
108108

109109
# <codecell>
110110

111-
traceplot(trace[100:], vars = [coeff_sd,sd ]);
111+
traceplot(trace[100:], varnames = [coeff_sd,sd ]);
112112

113113
# <codecell>
114114

pymc3/examples/disaster_model_missing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def run(n=1000):
6060

6161
tr = sample(n, tune=500, start=start, step=step)
6262

63-
summary(tr, vars=['disasters_missing'])
63+
summary(tr, varnames=['disasters_missing'])
6464

6565
if __name__ == '__main__':
6666
run()

pymc3/plots.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
__all__ = ['traceplot', 'kdeplot', 'kde2plot', 'forestplot', 'autocorrplot']
77

88

9-
def traceplot(trace, vars=None, figsize=None,
9+
def traceplot(trace, varnames=None, figsize=None,
1010
lines=None, combined=False, grid=True,
1111
alpha=0.35, ax=None):
1212
"""Plot samples histograms and values
@@ -15,7 +15,7 @@ def traceplot(trace, vars=None, figsize=None,
1515
----------
1616
1717
trace : result of MCMC run
18-
vars : list of variable names
18+
varnames : list of variable names
1919
Variables to be plotted, if None all variable are plotted
2020
figsize : figure size tuple
2121
If None, size is (12, num of variables * 2) inch
@@ -38,10 +38,10 @@ def traceplot(trace, vars=None, figsize=None,
3838
3939
"""
4040
import matplotlib.pyplot as plt
41-
if vars is None:
42-
vars = trace.varnames
41+
if varnames is None:
42+
varnames = trace.varnames
4343

44-
n = len(vars)
44+
n = len(varnames)
4545

4646
if figsize is None:
4747
figsize = (12, n*2)
@@ -52,7 +52,7 @@ def traceplot(trace, vars=None, figsize=None,
5252
print('traceplot requires n*2 subplots')
5353
return None
5454

55-
for i, v in enumerate(vars):
55+
for i, v in enumerate(varnames):
5656
for d in trace.get_values(v, combine=combined, squeeze=False):
5757
d = np.squeeze(d)
5858
d = make_2d(d)
@@ -259,7 +259,7 @@ def var_str(name, shape):
259259
return names
260260

261261

262-
def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
262+
def forestplot(trace_obj, varnames=None, alpha=0.05, quartiles=True, rhat=True,
263263
main=None, xtitle=None, xrange=None, ylabels=None,
264264
chain_spacing=0.05, vline=0, gs=None):
265265
""" Forest plot (model summary plot)
@@ -271,7 +271,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
271271
trace_obj: NpTrace or MultiTrace object
272272
Trace(s) from an MCMC sample.
273273
274-
vars: list
274+
varnames: list
275275
List of variables to plot (defaults to None, which results in all
276276
variables plotted).
277277
@@ -339,14 +339,14 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
339339
from .diagnostics import gelman_rubin
340340

341341
R = gelman_rubin(trace_obj)
342-
if vars is not None:
343-
R = {v: R[v] for v in vars}
342+
if varnames is not None:
343+
R = {v: R[v] for v in varnames}
344344
else:
345345
# Can't calculate Gelman-Rubin with a single trace
346346
rhat = False
347347

348-
if vars is None:
349-
vars = trace_obj.varnames
348+
if varnames is None:
349+
varnames = trace_obj.varnames
350350

351351
# Empty list for y-axis labels
352352
labels = []
@@ -370,7 +370,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
370370
for j, chain in enumerate(trace_obj.chains):
371371
# Counter for current variable
372372
var = 1
373-
for varname in vars:
373+
for varname in varnames:
374374
var_quantiles = trace_quantiles[chain][varname]
375375

376376
quants = [var_quantiles[v] for v in qlist]
@@ -533,7 +533,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True,
533533
plt.yticks([-(l + 1) for l in range(len(labels))], "")
534534

535535
i = 1
536-
for varname in vars:
536+
for varname in varnames:
537537

538538
chain = trace_obj.chains[0]
539539
value = trace_obj.get_values(varname, chains=[chain])[0]

pymc3/stats.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,14 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5)):
313313
print("Too few elements for quantile calculation")
314314

315315

316-
def df_summary(trace, vars=None, stat_funcs=None, extend=False,
316+
def df_summary(trace, varnames=None, stat_funcs=None, extend=False,
317317
alpha=0.05, batches=100):
318318
"""Create a data frame with summary statistics.
319319
320320
Parameters
321321
----------
322322
trace : MultiTrace instance
323-
vars : list
323+
varnames : list
324324
Names of variables to include in summary
325325
stat_funcs : None or list
326326
A list of functions used to calculate statistics. By default,
@@ -387,8 +387,8 @@ def df_summary(trace, vars=None, stat_funcs=None, extend=False,
387387
mu__0 0.066473 0.000312 0.105039 0.214242
388388
mu__1 0.067513 -0.159097 -0.045637 0.062912
389389
"""
390-
if vars is None:
391-
vars = trace.varnames
390+
if varnames is None:
391+
varnames = trace.varnames
392392

393393
funcs = [lambda x: pd.Series(np.mean(x, 0), name='mean'),
394394
lambda x: pd.Series(np.std(x, 0), name='sd'),
@@ -401,7 +401,7 @@ def df_summary(trace, vars=None, stat_funcs=None, extend=False,
401401
stat_funcs = funcs
402402

403403
var_dfs = []
404-
for var in vars:
404+
for var in varnames:
405405
vals = trace.get_values(var, combine=True)
406406
flat_vals = vals.reshape(vals.shape[0], -1)
407407
var_df = pd.concat([f(flat_vals) for f in stat_funcs], axis=1)
@@ -416,7 +416,7 @@ def _hpd_df(x, alpha):
416416
return pd.DataFrame(hpd(x, alpha), columns=cnames)
417417

418418

419-
def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3,
419+
def summary(trace, varnames=None, alpha=0.05, start=0, batches=100, roundto=3,
420420
to_file=None):
421421
"""
422422
Generate a pretty-printed summary of the node.
@@ -425,7 +425,7 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3,
425425
trace : Trace object
426426
Trace containing MCMC sample
427427
428-
vars : list of strings
428+
varnames : list of strings
429429
List of variables to summarize. Defaults to None, which results
430430
in all variables summarized.
431431
@@ -448,8 +448,8 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3,
448448
File to write results to. If not given, print to stdout.
449449
450450
"""
451-
if vars is None:
452-
vars = trace.varnames
451+
if varnames is None:
452+
varnames = trace.varnames
453453

454454
stat_summ = _StatSummary(roundto, batches, alpha)
455455
pq_summ = _PosteriorQuantileSummary(roundto, alpha)
@@ -459,7 +459,7 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3,
459459
else:
460460
fh = open(to_file, mode='w')
461461

462-
for var in vars:
462+
for var in varnames:
463463
# Extract sampled values
464464
sample = trace.get_values(var, burn=start, combine=True)
465465

pymc3/tests/test_plots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_multichain_plots():
5454
start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50}
5555
ptrace = sample(1000, [step1, step2], start, njobs=2)
5656

57-
forestplot(ptrace, vars=['early_mean', 'late_mean'])
57+
forestplot(ptrace, varnames=['early_mean', 'late_mean'])
5858

5959
autocorrplot(ptrace, varnames=['switchpoint'])
6060

0 commit comments

Comments
 (0)