Skip to content

Commit 33c8bff

Browse files
author
Junpeng Lao
authored
additional kwarg for pm.diagnostics (#2537)
* additional kwarg for pm.diagnostics add varnames and include_transformed * fix typo
1 parent 4e725eb commit 33c8bff

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

pymc3/diagnostics.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from .stats import statfunc
5+
from .util import get_default_varnames
56

67
__all__ = ['geweke', 'gelman_rubin', 'effective_n']
78

@@ -95,7 +96,7 @@ def geweke(x, first=.1, last=.5, intervals=20):
9596
return np.array(zscores)
9697

9798

98-
def gelman_rubin(mtrace):
99+
def gelman_rubin(mtrace, varnames=None, include_transformed=False):
99100
R"""Returns estimate of R for a set of traces.
100101
101102
The Gelman-Rubin diagnostic tests for lack of convergence by comparing
@@ -110,6 +111,11 @@ def gelman_rubin(mtrace):
110111
mtrace : MultiTrace
111112
A MultiTrace object containing parallel traces (minimum 2)
112113
of one or more stochastic parameters.
114+
varnames : list
115+
Names of variables to include in the rhat report
116+
include_transformed : bool
117+
Flag for reporting automatically transformed variables in addition
118+
to original variables (defaults to False).
113119
114120
Returns
115121
-------
@@ -140,8 +146,11 @@ def gelman_rubin(mtrace):
140146
'Gelman-Rubin diagnostic requires multiple chains '
141147
'of the same length.')
142148

149+
if varnames is None:
150+
varnames = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
151+
143152
Rhat = {}
144-
for var in mtrace.varnames:
153+
for var in varnames:
145154
x = np.array(mtrace.get_values(var, combine=False))
146155
num_samples = x.shape[1]
147156

@@ -159,14 +168,19 @@ def gelman_rubin(mtrace):
159168
return Rhat
160169

161170

162-
def effective_n(mtrace):
171+
def effective_n(mtrace, varnames=None, include_transformed=False):
163172
R"""Returns estimate of the effective sample size of a set of traces.
164173
165174
Parameters
166175
----------
167176
mtrace : MultiTrace
168-
A MultiTrace object containing parallel traces (minimum 2)
169-
of one or more stochastic parameters.
177+
A MultiTrace object containing parallel traces (minimum 2)
178+
of one or more stochastic parameters.
179+
varnames : list
180+
Names of variables to include in the effective_n report
181+
include_transformed : bool
182+
Flag for reporting automatically transformed variables in addition
183+
to original variables (defaults to False).
170184
171185
Returns
172186
-------
@@ -192,6 +206,9 @@ def effective_n(mtrace):
192206
'Calculation of effective sample size requires multiple chains '
193207
'of the same length.')
194208

209+
if varnames is None:
210+
varnames = get_default_varnames(mtrace.varnames, include_transformed=include_transformed)
211+
195212
def get_vhat(x):
196213
# number of chains is last dim (-1)
197214
# chain samples are second to last dim (-2)
@@ -234,7 +251,7 @@ def get_neff(x, Vhat):
234251
int(num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())))
235252

236253
n_eff = {}
237-
for var in mtrace.varnames:
254+
for var in varnames:
238255
x = np.array(mtrace.get_values(var, combine=False))
239256

240257
# make sure to handle scalars correctly - add extra dim if needed

0 commit comments

Comments
 (0)