Skip to content

Commit 3f3ff78

Browse files
authored
Merge pull request #3608 from aloctavodia/stats
Use diagnostics and stats from ArviZ
2 parents 5746788 + 521174c commit 3f3ff78

File tree

7 files changed

+123
-2002
lines changed

7 files changed

+123
-2002
lines changed

pymc3/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from .exceptions import *
2020
from . import sampling
2121

22-
from .diagnostics import *
2322
from .backends.tracetab import *
2423
from .backends import save_trace, load_trace
2524

pymc3/backends/report.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class SamplerReport:
4040
def __init__(self):
4141
self._chain_warnings = {}
4242
self._global_warnings = []
43-
self._effective_n = None
44-
self._gelman_rubin = None
43+
self._ess = None
44+
self._rhat = None
4545

4646
@property
4747
def _warnings(self):
@@ -69,7 +69,7 @@ def _run_convergence_checks(self, trace, model):
6969
self._add_warnings([warn])
7070
return
7171

72-
from pymc3 import diagnostics
72+
from pymc3 import rhat, ess
7373

7474
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
7575
varnames = []
@@ -81,50 +81,50 @@ def _run_convergence_checks(self, trace, model):
8181
if rv_name in trace.varnames:
8282
varnames.append(rv_name)
8383

84-
self._effective_n = effective_n = diagnostics.effective_n(trace, varnames)
85-
self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace, varnames)
84+
self._ess = ess = ess(trace, var_names=varnames)
85+
self._rhat = rhat = rhat(trace, var_names=varnames)
8686

8787
warnings = []
88-
rhat_max = max(val.max() for val in gelman_rubin.values())
88+
rhat_max = max(val.max() for val in rhat.values())
8989
if rhat_max > 1.4:
90-
msg = ("The gelman-rubin statistic is larger than 1.4 for some "
90+
msg = ("The rhat statistic is larger than 1.4 for some "
9191
"parameters. The sampler did not converge.")
9292
warn = SamplerWarning(
93-
WarningType.CONVERGENCE, msg, 'error', None, None, gelman_rubin)
93+
WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
9494
warnings.append(warn)
9595
elif rhat_max > 1.2:
96-
msg = ("The gelman-rubin statistic is larger than 1.2 for some "
96+
msg = ("The rhat statistic is larger than 1.2 for some "
9797
"parameters.")
9898
warn = SamplerWarning(
99-
WarningType.CONVERGENCE, msg, 'warn', None, None, gelman_rubin)
99+
WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
100100
warnings.append(warn)
101101
elif rhat_max > 1.05:
102-
msg = ("The gelman-rubin statistic is larger than 1.05 for some "
102+
msg = ("The rhat statistic is larger than 1.05 for some "
103103
"parameters. This indicates slight problems during "
104104
"sampling.")
105105
warn = SamplerWarning(
106-
WarningType.CONVERGENCE, msg, 'info', None, None, gelman_rubin)
106+
WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
107107
warnings.append(warn)
108108

109-
eff_min = min(val.min() for val in effective_n.values())
109+
eff_min = min(val.min() for val in ess.values())
110110
n_samples = len(trace) * trace.nchains
111111
if eff_min < 200 and n_samples >= 500:
112112
msg = ("The estimated number of effective samples is smaller than "
113113
"200 for some parameters.")
114114
warn = SamplerWarning(
115-
WarningType.CONVERGENCE, msg, 'error', None, None, effective_n)
115+
WarningType.CONVERGENCE, msg, 'error', None, None, ess)
116116
warnings.append(warn)
117117
elif eff_min / n_samples < 0.1:
118118
msg = ("The number of effective samples is smaller than "
119119
"10% for some parameters.")
120120
warn = SamplerWarning(
121-
WarningType.CONVERGENCE, msg, 'warn', None, None, effective_n)
121+
WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
122122
warnings.append(warn)
123123
elif eff_min / n_samples < 0.25:
124124
msg = ("The number of effective samples is smaller than "
125125
"25% for some parameters.")
126126
warn = SamplerWarning(
127-
WarningType.CONVERGENCE, msg, 'info', None, None, effective_n)
127+
WarningType.CONVERGENCE, msg, 'info', None, None, ess)
128128
warnings.append(warn)
129129

130130
self._add_warnings(warnings)

0 commit comments

Comments
 (0)