Skip to content

Commit ea1b7a4

Browse files
chang111junpenglao
authored andcommitted
Renew feature (#3416)
* to fix issue 3412# * to fix issue 3412# * to fix the issue 3412# * to fix issue 3412 * fix the issue 3412 * to fix issue 3412 * to fix the issue 3412# * 3412# * to fix issue 3412# * changes suggested by junpenglao * minor
1 parent ffbe432 commit ea1b7a4

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

pymc3/diagnostics.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .stats import statfunc, autocov
55
from .util import get_default_varnames
66
from .backends.base import MultiTrace
7+
import warnings
78

89
__all__ = ['geweke', 'gelman_rubin', 'effective_n']
910

@@ -97,7 +98,8 @@ def geweke(x, first=.1, last=.5, intervals=20):
9798
return np.array(zscores)
9899

99100

100-
def gelman_rubin(mtrace, varnames=None, include_transformed=False):
101+
102+
def gelman_rubin(mtrace, var_names=None, include_transformed=False, **kwargs):
101103
R"""Returns estimate of R for a set of traces.
102104
103105
The Gelman-Rubin diagnostic tests for lack of convergence by comparing
@@ -141,7 +143,12 @@ def gelman_rubin(mtrace, varnames=None, include_transformed=False):
141143
----------
142144
Brooks and Gelman (1998)
143145
Gelman and Rubin (1992)"""
144-
146+
if 'varnames' in kwargs:
147+
var_names = kwargs['varnames']
148+
warnings.warn(
149+
'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8',
150+
DeprecationWarning
151+
)
145152
def rscore(x, num_samples):
146153
# Calculate between-chain variance
147154
B = num_samples * np.var(np.mean(x, axis=1), axis=0, ddof=1)
@@ -163,20 +170,20 @@ def rscore(x, num_samples):
163170
'Gelman-Rubin diagnostic requires multiple chains '
164171
'of the same length.')
165172

166-
if varnames is None:
167-
varnames = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
173+
if var_names is None:
174+
var_names = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
168175

169176
Rhat = {}
170177

171-
for var in varnames:
178+
for var in var_names:
172179
x = np.array(mtrace.get_values(var, combine=False))
173180
num_samples = x.shape[1]
174181
Rhat[var] = rscore(x, num_samples)
175182

176183
return Rhat
177184

178185

179-
def effective_n(mtrace, varnames=None, include_transformed=False):
186+
def effective_n(mtrace, var_names=None, include_transformed=False, **kwargs):
180187
R"""Returns estimate of the effective sample size of a set of traces.
181188
182189
Parameters
@@ -211,7 +218,11 @@ def effective_n(mtrace, varnames=None, include_transformed=False):
211218
References
212219
----------
213220
Gelman et al. BDA (2014)"""
214-
221+
if 'varnames' in kwargs:
222+
var_names = kwargs['varnames']
223+
warnings.warn(
224+
'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8',
225+
DeprecationWarning)
215226
def get_neff(x):
216227
"""Compute the effective sample size for a 2D array
217228
"""
@@ -291,12 +302,12 @@ def generate_neff(trace_values):
291302
'Calculation of effective sample size requires multiple chains '
292303
'of the same length.')
293304

294-
if varnames is None:
295-
varnames = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
305+
if var_names is None:
306+
var_names = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
296307

297308
n_eff = {}
298309

299-
for var in varnames:
310+
for var in var_names:
300311
n_eff[var] = generate_neff(mtrace.get_values(var, combine=False))
301312

302313
return n_eff

pymc3/stats.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pymc3 as pm
1717
from pymc3.theanof import floatX
1818

19+
1920
if pkg_resources.get_distribution('scipy').version < '1.0.0':
2021
from scipy.misc import logsumexp
2122
else:
@@ -850,9 +851,11 @@ def dict2pd(statdict, labelname):
850851
statpd = statpd.rename(labelname)
851852
return statpd
852853

853-
def summary(trace, varnames=None, transform=lambda x: x, stat_funcs=None,
854+
855+
856+
def summary(trace, var_names=None, transform=lambda x: x, stat_funcs=None,
854857
extend=False, include_transformed=False,
855-
alpha=0.05, start=0, batches=None):
858+
alpha=0.05, start=0, batches=None, **kwargs):
856859
R"""Create a data frame with summary statistics.
857860
858861
Parameters
@@ -936,10 +939,16 @@ def summary(trace, varnames=None, transform=lambda x: x, stat_funcs=None,
936939
mu__0 0.066473 0.000312 0.105039 0.214242
937940
mu__1 0.067513 -0.159097 -0.045637 0.062912
938941
"""
942+
if 'varnames' in kwargs:
943+
var_names = kwargs['varnames']
944+
warnings.warn(
945+
'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8',
946+
DeprecationWarning
947+
)
939948
from .backends import tracetab as ttab
940949

941-
if varnames is None:
942-
varnames = get_default_varnames(trace.varnames,
950+
if var_names is None:
951+
var_names = get_default_varnames(trace.varnames,
943952
include_transformed=include_transformed)
944953

945954
if batches is None:
@@ -957,7 +966,7 @@ def summary(trace, varnames=None, transform=lambda x: x, stat_funcs=None,
957966
funcs = stat_funcs
958967

959968
var_dfs = []
960-
for var in varnames:
969+
for var in var_names:
961970
vals = transform(trace.get_values(var, burn=start, combine=True))
962971
flat_vals = vals.reshape(vals.shape[0], -1)
963972
var_df = pd.concat([f(flat_vals) for f in funcs], axis=1)
@@ -971,11 +980,11 @@ def summary(trace, varnames=None, transform=lambda x: x, stat_funcs=None,
971980
return dforg
972981
else:
973982
n_eff = pm.effective_n(trace,
974-
varnames=varnames,
983+
varnames=var_names,
975984
include_transformed=include_transformed)
976985
n_eff_pd = dict2pd(n_eff, 'n_eff')
977986
rhat = pm.gelman_rubin(trace,
978-
varnames=varnames,
987+
varnames=var_names,
979988
include_transformed=include_transformed)
980989
rhat_pd = dict2pd(rhat, 'Rhat')
981990
return pd.concat([dforg, n_eff_pd, rhat_pd],

0 commit comments

Comments
 (0)