@@ -62,30 +62,37 @@ class SamplerWarning:
62
62
63
63
64
64
def run_convergence_checks (idata : arviz .InferenceData , model ) -> list [SamplerWarning ]:
65
+ warnings : list [SamplerWarning ] = []
66
+
65
67
if not hasattr (idata , "posterior" ):
66
68
msg = "No posterior samples. Unable to run convergence checks"
67
69
warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" , None , None , None )
68
- return [warn ]
70
+ warnings .append (warn )
71
+ return warnings
72
+
73
+ warnings += warn_divergences (idata )
74
+ warnings += warn_treedepth (idata )
69
75
70
76
if idata ["posterior" ].sizes ["draw" ] < 100 :
71
77
msg = "The number of samples is too small to check convergence reliably."
72
78
warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" , None , None , None )
73
- return [warn ]
79
+ warnings .append (warn )
80
+ return warnings
74
81
75
82
if idata ["posterior" ].sizes ["chain" ] == 1 :
76
83
msg = "Only one chain was sampled, this makes it impossible to run some convergence checks"
77
84
warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
78
- return [warn ]
85
+ warnings .append (warn )
86
+ return warnings
79
87
80
88
elif idata ["posterior" ].sizes ["chain" ] < 4 :
81
89
msg = (
82
90
"We recommend running at least 4 chains for robust computation of "
83
91
"convergence diagnostics"
84
92
)
85
93
warn = SamplerWarning (WarningType .BAD_PARAMS , msg , "info" )
86
- return [ warn ]
94
+ warnings . append ( warn )
87
95
88
- warnings : list [SamplerWarning ] = []
89
96
valid_name = [rv .name for rv in model .free_RVs + model .deterministics ]
90
97
varnames = []
91
98
for rv in model .free_RVs :
@@ -99,7 +106,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
99
106
ess = arviz .ess (idata , var_names = varnames )
100
107
rhat = arviz .rhat (idata , var_names = varnames )
101
108
102
- warnings = []
103
109
rhat_max = max (val .max () for val in rhat .values ())
104
110
if rhat_max > 1.01 :
105
111
msg = (
@@ -121,9 +127,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
121
127
warn = SamplerWarning (WarningType .CONVERGENCE , msg , "error" , extra = ess )
122
128
warnings .append (warn )
123
129
124
- warnings += warn_divergences (idata )
125
- warnings += warn_treedepth (idata )
126
-
127
130
return warnings
128
131
129
132
0 commit comments