
Commit 39cd75d

aseyboldt authored and Junpeng Lao committed
Refactor HMC and warning system (#2677)

* Move DualAveraging out of nuts
* Use dual averaging in hmc
* Update doc of dual average parameters
* Rewrite warning system
* Add stats to hmc
* Change SamplerWarning.kind to enum
1 parent a547cd5 commit 39cd75d

18 files changed: 731 additions & 380 deletions
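For orientation, here is a minimal sketch (not part of the commit) of how the
reworked warning system surfaces to users; the toy model is hypothetical, but
`trace.report` and its methods are added in this diff:

    import pymc3 as pm

    with pm.Model():
        mu = pm.Normal('mu', mu=0., sd=1.)
        trace = pm.sample(1000, tune=500)

    # The returned MultiTrace now carries a SamplerReport assembled
    # from per-chain sampler warnings.
    print(trace.report.ok)   # False if any warning is 'warn' level or above
    trace.report.raise_ok()  # raises ValueError on 'error'-level warnings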

pymc3/backends/base.py

Lines changed: 29 additions & 9 deletions
@@ -4,12 +4,16 @@
     creating custom backends).
 """
 import itertools as itl
+import logging
 
 import numpy as np
 import warnings
 import theano.tensor as tt
 
 from ..model import modelcontext
+from .report import SamplerReport, merge_reports
+
+logger = logging.getLogger('pymc3')
 
 
 class BackendError(Exception):
@@ -61,6 +65,10 @@ def __init__(self, name, model=None, vars=None, test_point=None):
         self.chain = None
         self._is_base_setup = False
         self.sampler_vars = None
+        self._warnings = []
+
+    def _add_warnings(self, warnings):
+        self._warnings.extend(warnings)
 
     # Sampling methods
 
@@ -174,7 +182,7 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
             return self._get_sampler_stats(varname, sampler_idx, burn, thin)
 
         sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
-                       if varname in s]
+                        if varname in s]
         if not sampler_idxs:
             raise KeyError("Unknown sampler stat %s" % varname)
 
@@ -185,20 +193,19 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
         else:
             return vals
 
-
     def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
         """Get sampler statistics."""
         raise NotImplementedError()
 
     def _slice(self, idx):
         """Slice trace object."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     def point(self, idx):
         """Return dictionary of point values at `idx` for current chain
         with variables names as keys.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @property
     def stat_names(self):
@@ -258,6 +265,11 @@ def __init__(self, straces):
                 raise ValueError("Chains are not unique.")
             self._straces[strace.chain] = strace
 
+        self._report = SamplerReport()
+        for strace in straces:
+            if hasattr(strace, '_warnings'):
+                self._report._add_warnings(strace._warnings, strace.chain)
+
     def __repr__(self):
         template = '<{}: {} chains, {} iterations, {} variables>'
         return template.format(self.__class__.__name__,
@@ -271,6 +283,10 @@ def nchains(self):
     def chains(self):
         return list(sorted(self._straces.keys()))
 
+    @property
+    def report(self):
+        return self._report
+
     def __getitem__(self, idx):
         if isinstance(idx, slice):
             return self._slice(idx)
@@ -303,7 +319,7 @@ def __getitem__(self, idx):
         raise KeyError("Unknown variable %s" % var)
 
     _attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
-                  'supports_sampler_stats'])
+                  'supports_sampler_stats', '_report'])
 
     def __getattr__(self, name):
         # Avoid infinite recursion when called before __init__
@@ -447,10 +463,13 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
                    for chain in chains]
         return _squeeze_cat(results, combine, squeeze)
 
-    def _slice(self, idx):
-        """Return a new MultiTrace object sliced according to `idx`."""
-        new_traces = [trace._slice(idx) for trace in self._straces.values()]
-        return MultiTrace(new_traces)
+    def _slice(self, slice):
+        """Return a new MultiTrace object sliced according to `slice`."""
+        new_traces = [trace._slice(slice) for trace in self._straces.values()]
+        trace = MultiTrace(new_traces)
+        idxs = slice.indices(len(self))
+        trace._report = self._report._slice(*idxs)
+        return trace
 
     def point(self, idx, chain=None):
         """Return a dictionary of point values at `idx`.
@@ -502,6 +521,7 @@ def merge_traces(mtraces):
             if new_chain in base_mtrace._straces:
                 raise ValueError("Chains are not unique.")
             base_mtrace._straces[new_chain] = strace
+    base_mtrace._report = merge_reports([trace.report for trace in mtraces])
     return base_mtrace
 
 

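A note on the slice bookkeeping in `MultiTrace._slice` above: the incoming
Python slice object is first normalized to a concrete (start, stop, step)
triple with the built-in `slice.indices` before being handed to
`SamplerReport._slice`, e.g.:

    # Normalizing a slice against a hypothetical trace of length 2000.
    s = slice(500, None, 2)
    print(s.indices(2000))  # (500, 2000, 2)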
pymc3/backends/ndarray.py

Lines changed: 3 additions & 2 deletions
@@ -116,8 +116,9 @@ def close(self):
         self.samples = {var: vtrace[:self.draw_idx]
                         for var, vtrace in self.samples.items()}
         if self._stats is not None:
-            self._stats = [{var: trace[:self.draw_idx] for var, trace in stats.items()}
-                           for stats in self._stats]
+            self._stats = [
+                {var: trace[:self.draw_idx] for var, trace in stats.items()}
+                for stats in self._stats]
 
     # Selection methods
 

pymc3/backends/report.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+from collections import namedtuple
+import logging
+import enum
+
+
+logger = logging.getLogger('pymc3')
+
+
+@enum.unique
+class WarningType(enum.Enum):
+    # For HMC and NUTS
+    DIVERGENCE = 1
+    TUNING_DIVERGENCE = 2
+    DIVERGENCES = 3
+    TREEDEPTH = 4
+    # Problematic sampler parameters
+    BAD_PARAMS = 5
+    # Indications that chains did not converge, eg Rhat
+    CONVERGENCE = 6
+    BAD_ACCEPTANCE = 7
+
+
+SamplerWarning = namedtuple(
+    'SamplerWarning',
+    "kind, message, level, step, exec_info, extra")
+
+
+_LEVELS = {
+    'info': logging.INFO,
+    'error': logging.ERROR,
+    'warn': logging.WARN,
+    'debug': logging.DEBUG,
+    'critical': logging.CRITICAL,
+}
+
+
+class SamplerReport(object):
+    def __init__(self):
+        self._chain_warnings = {}
+        self._global_warnings = []
+        self._effective_n = None
+        self._gelman_rubin = None
+
+    @property
+    def _warnings(self):
+        chains = sum(self._chain_warnings.values(), [])
+        return chains + self._global_warnings
+
+    @property
+    def ok(self):
+        """Whether the automatic convergence checks found no serious problems."""
+        return all(_LEVELS[warn.level] < _LEVELS['warn']
+                   for warn in self._warnings)
+
+    def raise_ok(self, level='error'):
+        errors = [warn for warn in self._warnings
+                  if _LEVELS[warn.level] >= _LEVELS[level]]
+        if errors:
+            raise ValueError('Serious convergence issues during sampling.')
+
+    def _run_convergence_checks(self, trace):
+        if trace.nchains == 1:
+            msg = ("Only one chain was sampled, this makes it impossible to "
+                   "run some convergence checks")
+            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
+                                  None, None, None)
+            self._add_warnings([warn])
+            return
+
+        from pymc3 import diagnostics
+
+        self._effective_n = effective_n = diagnostics.effective_n(trace)
+        self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace)
+
+        warnings = []
+        rhat_max = max(val.max() for val in gelman_rubin.values())
+        if rhat_max > 1.4:
+            msg = ("The gelman-rubin statistic is larger than 1.4 for some "
+                   "parameters. The sampler did not converge.")
+            warn = SamplerWarning(
+                WarningType.CONVERGENCE, msg, 'error', None, None, gelman_rubin)
+            warnings.append(warn)
+        elif rhat_max > 1.2:
+            msg = ("The gelman-rubin statistic is larger than 1.2 for some "
+                   "parameters.")
+            warn = SamplerWarning(
+                WarningType.CONVERGENCE, msg, 'warn', None, None, gelman_rubin)
+            warnings.append(warn)
+        elif rhat_max > 1.05:
+            msg = ("The gelman-rubin statistic is larger than 1.05 for some "
+                   "parameters. This indicates slight problems during "
+                   "sampling.")
+            warn = SamplerWarning(
+                WarningType.CONVERGENCE, msg, 'info', None, None, gelman_rubin)
+            warnings.append(warn)
+
+        eff_min = min(val.min() for val in effective_n.values())
+        n_samples = len(trace) * trace.nchains
+        if eff_min < 200 and n_samples >= 500:
+            msg = ("The estimated number of effective samples is smaller than "
+                   "200 for some parameters.")
+            warn = SamplerWarning(
+                WarningType.CONVERGENCE, msg, 'error', None, None, effective_n)
+            warnings.append(warn)
+        elif eff_min / n_samples < 0.25:
+            msg = ("The number of effective samples is smaller than "
+                   "25% for some parameters.")
+            warn = SamplerWarning(
+                WarningType.CONVERGENCE, msg, 'warn', None, None, effective_n)
+            warnings.append(warn)
+
+        self._add_warnings(warnings)
+
+    def _add_warnings(self, warnings, chain=None):
+        if chain is None:
+            warn_list = self._global_warnings
+        else:
+            warn_list = self._chain_warnings.setdefault(chain, [])
+        warn_list.extend(warnings)
+
+    def _log_summary(self):
+
+        def log_warning(warn):
+            level = _LEVELS[warn.level]
+            logger.log(level, warn.message)
+
+        for chain, warns in self._chain_warnings.items():
+            for warn in warns:
+                log_warning(warn)
+        for warn in self._global_warnings:
+            log_warning(warn)
+
+    def _slice(self, start, stop, step):
+        report = SamplerReport()
+
+        def filter_warns(warnings):
+            filtered = []
+            for warn in warnings:
+                if warn.step is None:
+                    filtered.append(warn)
+                elif (start <= warn.step < stop and
+                        (warn.step - start) % step == 0):
+                    warn = warn._replace(step=warn.step - start)
+                    filtered.append(warn)
+            return filtered
+
+        report._add_warnings(filter_warns(self._global_warnings))
+        for chain in self._chain_warnings:
+            report._add_warnings(
+                filter_warns(self._chain_warnings[chain]),
+                chain)
+
+        return report
+
+
+def merge_reports(reports):
+    report = SamplerReport()
+    for rep in reports:
+        report._add_warnings(rep._global_warnings)
+        for chain in rep._chain_warnings:
+            report._add_warnings(rep._chain_warnings[chain], chain)
+    return report
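A standalone sketch of the new warning API (assuming the module imports as
pymc3.backends.report; the warnings themselves are fabricated for
illustration):

    from pymc3.backends.report import (
        SamplerReport, SamplerWarning, WarningType)

    report = SamplerReport()
    report._add_warnings([
        SamplerWarning(WarningType.DIVERGENCE, 'divergence after tuning',
                       'debug', 600, None, None)], chain=0)
    print(report.ok)  # True: 'debug' is below the 'warn' threshold

    # Step-indexed warnings survive slicing, re-indexed to the new origin.
    sliced = report._slice(500, 1000, 2)
    print(sliced._chain_warnings[0][0].step)  # 600 -> 100

    report._add_warnings([
        SamplerWarning(WarningType.CONVERGENCE, 'gelman-rubin > 1.4',
                       'error', None, None, None)])
    print(report.ok)  # False
    try:
        report.raise_ok()
    except ValueError:
        pass  # 'Serious convergence issues during sampling.'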

pymc3/diagnostics.py

Lines changed: 2 additions & 2 deletions
@@ -250,8 +250,8 @@ def get_neff(x, Vhat):
         if t % 2:
             t -= 1
 
-        return min(num_chains * num_samples,
-                   int(num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())))
+        neff = num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())
+        return min(num_chains * num_samples, np.floor(neff))
 
     def generate_neff(trace_values):
         x = np.array(trace_values)

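A worked sketch of the quantity changed above, with hypothetical
autocorrelations: the estimate is n_eff = m * n / (1 + 2 * sum(rho_t)), now
floored after the division (yielding a float) instead of truncated with int()
inside the min:

    import numpy as np

    num_chains, num_samples = 4, 500
    rho = np.array([1.0, 0.6, 0.3, 0.1, 0.02])  # hypothetical autocorrelations
    t = 4  # illustrative cutoff lag

    neff = num_chains * num_samples / (1. + 2 * rho[1:t - 1].sum())
    print(min(num_chains * num_samples, np.floor(neff)))  # 714.0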
pymc3/model.py

Lines changed: 7 additions & 7 deletions
@@ -10,7 +10,7 @@
 from theano import theano, tensor as tt
 from theano.tensor.var import TensorVariable
 
-from pymc3.theanof import set_theano_conf
+from pymc3.theanof import set_theano_conf, floatX
 import pymc3 as pm
 from pymc3.math import flatten_list
 from .memoize import memoize, WithMemoization
@@ -1061,13 +1061,13 @@ def _get_scaling(total_size, shape, ndim):
         scalar
     """
     if total_size is None:
-        coef = pm.floatX(1)
+        coef = floatX(1)
     elif isinstance(total_size, int):
         if ndim >= 1:
             denom = shape[0]
         else:
             denom = 1
-        coef = pm.floatX(total_size) / pm.floatX(denom)
+        coef = floatX(total_size) / floatX(denom)
     elif isinstance(total_size, (list, tuple)):
         if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
             raise TypeError('Unrecognized `total_size` type, expected '
@@ -1085,20 +1085,20 @@ def _get_scaling(total_size, shape, ndim):
             raise ValueError('Length of `total_size` is too big, '
                              'number of scalings is bigger that ndim, got %r' % total_size)
         elif (len(begin) + len(end)) == 0:
-            return pm.floatX(1)
+            return floatX(1)
         if len(end) > 0:
             shp_end = shape[-len(end):]
         else:
             shp_end = np.asarray([])
         shp_begin = shape[:len(begin)]
-        begin_coef = [pm.floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
-        end_coef = [pm.floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
+        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
+        end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
         coefs = begin_coef + end_coef
         coef = tt.prod(coefs)
     else:
         raise TypeError('Unrecognized `total_size` type, expected '
                         'int or list of ints, got %r' % total_size)
-    return tt.as_tensor(pm.floatX(coef))
+    return tt.as_tensor(floatX(coef))
 
 
 class FreeRV(Factor, TensorVariable):

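For context, `_get_scaling` computes the constant used to upweight
log-probability terms when `total_size` is given (as in minibatch training);
the commit only swaps `pm.floatX` for a direct `floatX` import. A sketch of
the int-`total_size` branch, using numpy floats in place of theano's floatX:

    import numpy as np

    total_size, shape, ndim = 10000, np.array([100, 5]), 2
    denom = shape[0] if ndim >= 1 else 1
    coef = np.float64(total_size) / np.float64(denom)
    print(coef)  # 100.0: each term is upweighted by total_size / batch_size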