Skip to content

Commit bd34758

Browse files
committed
Always run divergence and treedepth checks in run_convergence_checks
1 parent 879dc0e commit bd34758

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

pymc/stats/convergence.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,37 @@ class SamplerWarning:
6262

6363

6464
def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWarning]:
65+
warnings: list[SamplerWarning] = []
66+
6567
if not hasattr(idata, "posterior"):
6668
msg = "No posterior samples. Unable to run convergence checks"
6769
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
68-
return [warn]
70+
warnings.append(warn)
71+
return warnings
72+
73+
warnings += warn_divergences(idata)
74+
warnings += warn_treedepth(idata)
6975

7076
if idata["posterior"].sizes["draw"] < 100:
7177
msg = "The number of samples is too small to check convergence reliably."
7278
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
73-
return [warn]
79+
warnings.append(warn)
80+
return warnings
7481

7582
if idata["posterior"].sizes["chain"] == 1:
7683
msg = "Only one chain was sampled, this makes it impossible to run some convergence checks"
7784
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
78-
return [warn]
85+
warnings.append(warn)
86+
return warnings
7987

8088
elif idata["posterior"].sizes["chain"] < 4:
8189
msg = (
8290
"We recommend running at least 4 chains for robust computation of "
8391
"convergence diagnostics"
8492
)
8593
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
86-
return [warn]
94+
warnings.append(warn)
8795

88-
warnings: list[SamplerWarning] = []
8996
valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
9097
varnames = []
9198
for rv in model.free_RVs:
@@ -99,7 +106,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
99106
ess = arviz.ess(idata, var_names=varnames)
100107
rhat = arviz.rhat(idata, var_names=varnames)
101108

102-
warnings = []
103109
rhat_max = max(val.max() for val in rhat.values())
104110
if rhat_max > 1.01:
105111
msg = (
@@ -121,9 +127,6 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWar
121127
warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
122128
warnings.append(warn)
123129

124-
warnings += warn_divergences(idata)
125-
warnings += warn_treedepth(idata)
126-
127130
return warnings
128131

129132

0 commit comments

Comments
 (0)