@@ -40,8 +40,8 @@ class SamplerReport:
40
40
def __init__ (self ):
41
41
self ._chain_warnings = {}
42
42
self ._global_warnings = []
43
- self ._effective_n = None
44
- self ._gelman_rubin = None
43
+ self ._ess = None
44
+ self ._rhat = None
45
45
46
46
@property
47
47
def _warnings (self ):
@@ -69,7 +69,7 @@ def _run_convergence_checks(self, trace, model):
69
69
self ._add_warnings ([warn ])
70
70
return
71
71
72
- from pymc3 import diagnostics
72
+ from pymc3 import rhat , ess
73
73
74
74
valid_name = [rv .name for rv in model .free_RVs + model .deterministics ]
75
75
varnames = []
@@ -81,50 +81,50 @@ def _run_convergence_checks(self, trace, model):
81
81
if rv_name in trace .varnames :
82
82
varnames .append (rv_name )
83
83
84
- self ._effective_n = effective_n = diagnostics . effective_n (trace , varnames )
85
- self ._gelman_rubin = gelman_rubin = diagnostics . gelman_rubin (trace , varnames )
84
+ self ._ess = ess = ess (trace , var_names = varnames )
85
+ self ._rhat = rhat = rhat (trace , var_names = varnames )
86
86
87
87
warnings = []
88
- rhat_max = max (val .max () for val in gelman_rubin .values ())
88
+ rhat_max = max (val .max () for val in rhat .values ())
89
89
if rhat_max > 1.4 :
90
- msg = ("The gelman-rubin statistic is larger than 1.4 for some "
90
+ msg = ("The rhat statistic is larger than 1.4 for some "
91
91
"parameters. The sampler did not converge." )
92
92
warn = SamplerWarning (
93
- WarningType .CONVERGENCE , msg , 'error' , None , None , gelman_rubin )
93
+ WarningType .CONVERGENCE , msg , 'error' , None , None , rhat )
94
94
warnings .append (warn )
95
95
elif rhat_max > 1.2 :
96
- msg = ("The gelman-rubin statistic is larger than 1.2 for some "
96
+ msg = ("The rhat statistic is larger than 1.2 for some "
97
97
"parameters." )
98
98
warn = SamplerWarning (
99
- WarningType .CONVERGENCE , msg , 'warn' , None , None , gelman_rubin )
99
+ WarningType .CONVERGENCE , msg , 'warn' , None , None , rhat )
100
100
warnings .append (warn )
101
101
elif rhat_max > 1.05 :
102
- msg = ("The gelman-rubin statistic is larger than 1.05 for some "
102
+ msg = ("The rhat statistic is larger than 1.05 for some "
103
103
"parameters. This indicates slight problems during "
104
104
"sampling." )
105
105
warn = SamplerWarning (
106
- WarningType .CONVERGENCE , msg , 'info' , None , None , gelman_rubin )
106
+ WarningType .CONVERGENCE , msg , 'info' , None , None , rhat )
107
107
warnings .append (warn )
108
108
109
- eff_min = min (val .min () for val in effective_n .values ())
109
+ eff_min = min (val .min () for val in ess .values ())
110
110
n_samples = len (trace ) * trace .nchains
111
111
if eff_min < 200 and n_samples >= 500 :
112
112
msg = ("The estimated number of effective samples is smaller than "
113
113
"200 for some parameters." )
114
114
warn = SamplerWarning (
115
- WarningType .CONVERGENCE , msg , 'error' , None , None , effective_n )
115
+ WarningType .CONVERGENCE , msg , 'error' , None , None , ess )
116
116
warnings .append (warn )
117
117
elif eff_min / n_samples < 0.1 :
118
118
msg = ("The number of effective samples is smaller than "
119
119
"10% for some parameters." )
120
120
warn = SamplerWarning (
121
- WarningType .CONVERGENCE , msg , 'warn' , None , None , effective_n )
121
+ WarningType .CONVERGENCE , msg , 'warn' , None , None , ess )
122
122
warnings .append (warn )
123
123
elif eff_min / n_samples < 0.25 :
124
124
msg = ("The number of effective samples is smaller than "
125
125
"25% for some parameters." )
126
126
warn = SamplerWarning (
127
- WarningType .CONVERGENCE , msg , 'info' , None , None , effective_n )
127
+ WarningType .CONVERGENCE , msg , 'info' , None , None , ess )
128
128
warnings .append (warn )
129
129
130
130
self ._add_warnings (warnings )
0 commit comments