4
4
from .stats import statfunc , autocov
5
5
from .util import get_default_varnames
6
6
from .backends .base import MultiTrace
7
+ import warnings
7
8
8
9
__all__ = ['geweke' , 'gelman_rubin' , 'effective_n' ]
9
10
@@ -97,7 +98,8 @@ def geweke(x, first=.1, last=.5, intervals=20):
97
98
return np .array (zscores )
98
99
99
100
100
- def gelman_rubin (mtrace , varnames = None , include_transformed = False ):
101
+
102
+ def gelman_rubin (mtrace , var_names = None , include_transformed = False , ** kwargs ):
101
103
R"""Returns estimate of R for a set of traces.
102
104
103
105
The Gelman-Rubin diagnostic tests for lack of convergence by comparing
@@ -141,7 +143,12 @@ def gelman_rubin(mtrace, varnames=None, include_transformed=False):
141
143
----------
142
144
Brooks and Gelman (1998)
143
145
Gelman and Rubin (1992)"""
144
-
146
+ if 'varnames' in kwargs :
147
+ var_names = kwargs ['varnames' ]
148
+ warnings .warn (
149
+ 'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8' ,
150
+ DeprecationWarning
151
+ )
145
152
def rscore (x , num_samples ):
146
153
# Calculate between-chain variance
147
154
B = num_samples * np .var (np .mean (x , axis = 1 ), axis = 0 , ddof = 1 )
@@ -163,20 +170,20 @@ def rscore(x, num_samples):
163
170
'Gelman-Rubin diagnostic requires multiple chains '
164
171
'of the same length.' )
165
172
166
- if varnames is None :
167
- varnames = get_default_varnames (mtrace .varnames , include_transformed = include_transformed )
173
+ if var_names is None :
174
+ var_names = get_default_varnames (mtrace .varnames , include_transformed = include_transformed )
168
175
169
176
Rhat = {}
170
177
171
- for var in varnames :
178
+ for var in var_names :
172
179
x = np .array (mtrace .get_values (var , combine = False ))
173
180
num_samples = x .shape [1 ]
174
181
Rhat [var ] = rscore (x , num_samples )
175
182
176
183
return Rhat
177
184
178
185
179
- def effective_n (mtrace , varnames = None , include_transformed = False ):
186
+ def effective_n (mtrace , var_names = None , include_transformed = False , ** kwargs ):
180
187
R"""Returns estimate of the effective sample size of a set of traces.
181
188
182
189
Parameters
@@ -211,7 +218,11 @@ def effective_n(mtrace, varnames=None, include_transformed=False):
211
218
References
212
219
----------
213
220
Gelman et al. BDA (2014)"""
214
-
221
+ if 'varnames' in kwargs :
222
+ var_names = kwargs ['varnames' ]
223
+ warnings .warn (
224
+ 'Keyword argument varnames renamed to var_names, and will be removed in pymc3 3.8' ,
225
+ DeprecationWarning )
215
226
def get_neff (x ):
216
227
"""Compute the effective sample size for a 2D array
217
228
"""
@@ -291,12 +302,12 @@ def generate_neff(trace_values):
291
302
'Calculation of effective sample size requires multiple chains '
292
303
'of the same length.' )
293
304
294
- if varnames is None :
295
- varnames = get_default_varnames (mtrace .varnames , include_transformed = include_transformed )
305
+ if var_names is None :
306
+ var_names = get_default_varnames (mtrace .varnames , include_transformed = include_transformed )
296
307
297
308
n_eff = {}
298
309
299
- for var in varnames :
310
+ for var in var_names :
300
311
n_eff [var ] = generate_neff (mtrace .get_values (var , combine = False ))
301
312
302
313
return n_eff
0 commit comments