2
2
3
3
import numpy as np
4
4
from .stats import statfunc
5
+ from .util import get_default_varnames
5
6
6
7
__all__ = ['geweke' , 'gelman_rubin' , 'effective_n' ]
7
8
@@ -95,7 +96,7 @@ def geweke(x, first=.1, last=.5, intervals=20):
95
96
return np .array (zscores )
96
97
97
98
98
- def gelman_rubin (mtrace ):
99
+ def gelman_rubin (mtrace , varnames = None , include_transformed = False ):
99
100
R"""Returns estimate of R for a set of traces.
100
101
101
102
The Gelman-Rubin diagnostic tests for lack of convergence by comparing
@@ -110,6 +111,11 @@ def gelman_rubin(mtrace):
110
111
mtrace : MultiTrace
111
112
A MultiTrace object containing parallel traces (minimum 2)
112
113
of one or more stochastic parameters.
114
+ varnames : list
115
+ Names of variables to include in the rhat report
116
+ include_transformed : bool
117
+ Flag for reporting automatically transformed variables in addition
118
+ to original variables (defaults to False).
113
119
114
120
Returns
115
121
-------
@@ -140,8 +146,11 @@ def gelman_rubin(mtrace):
140
146
'Gelman-Rubin diagnostic requires multiple chains '
141
147
'of the same length.' )
142
148
149
+ if varnames is None :
150
+ varnames = get_default_varnames (mtrace .varnames , include_transformed = include_transformed )
151
+
143
152
Rhat = {}
144
- for var in mtrace . varnames :
153
+ for var in varnames :
145
154
x = np .array (mtrace .get_values (var , combine = False ))
146
155
num_samples = x .shape [1 ]
147
156
@@ -159,14 +168,19 @@ def gelman_rubin(mtrace):
159
168
return Rhat
160
169
161
170
162
- def effective_n (mtrace ):
171
+ def effective_n (mtrace , varnames = None , include_transformed = False ):
163
172
R"""Returns estimate of the effective sample size of a set of traces.
164
173
165
174
Parameters
166
175
----------
167
176
mtrace : MultiTrace
168
- A MultiTrace object containing parallel traces (minimum 2)
169
- of one or more stochastic parameters.
177
+ A MultiTrace object containing parallel traces (minimum 2)
178
+ of one or more stochastic parameters.
179
+ varnames : list
180
+ Names of variables to include in the effective_n report
181
+ include_transformed : bool
182
+ Flag for reporting automatically transformed variables in addition
183
+ to original variables (defaults to False).
170
184
171
185
Returns
172
186
-------
@@ -192,6 +206,9 @@ def effective_n(mtrace):
192
206
'Calculation of effective sample size requires multiple chains '
193
207
'of the same length.' )
194
208
209
+ if varnames is None :
210
+ varnames = get_default_varnames (mtrace .varnames , include_transformed = include_transformed )
211
+
195
212
def get_vhat (x ):
196
213
# number of chains is last dim (-1)
197
214
# chain samples are second to last dim (-2)
@@ -234,7 +251,7 @@ def get_neff(x, Vhat):
234
251
int (num_chains * num_samples / (1. + 2 * rho [1 :t - 1 ].sum ())))
235
252
236
253
n_eff = {}
237
- for var in mtrace . varnames :
254
+ for var in varnames :
238
255
x = np .array (mtrace .get_values (var , combine = False ))
239
256
240
257
# make sure to handle scalars correctly - add extra dim if needed
0 commit comments