Refactor HMC and warning system #2677

Merged: 6 commits, Jan 8, 2018. The diff below shows changes from 5 commits.
36 changes: 28 additions & 8 deletions pymc3/backends/base.py
@@ -4,12 +4,16 @@
creating custom backends).
"""
import itertools as itl
+import logging

import numpy as np
import warnings
import theano.tensor as tt

from ..model import modelcontext
+from .report import SamplerReport, merge_reports
+
+logger = logging.getLogger('pymc3')


class BackendError(Exception):
@@ -61,6 +65,10 @@ def __init__(self, name, model=None, vars=None, test_point=None):
        self.chain = None
        self._is_base_setup = False
        self.sampler_vars = None
+        self._warnings = []
+
+    def _add_warnings(self, warnings):
+        self._warnings.extend(warnings)

    # Sampling methods

@@ -174,7 +182,7 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
            return self._get_sampler_stats(varname, sampler_idx, burn, thin)

        sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
-                       if varname in s]
+                        if varname in s]
        if not sampler_idxs:
            raise KeyError("Unknown sampler stat %s" % varname)

@@ -185,20 +193,19 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
        else:
            return vals

-
    def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
        """Get sampler statistics."""
        raise NotImplementedError()

    def _slice(self, idx):
        """Slice trace object."""
-        raise NotImplementedError
+        raise NotImplementedError()

    def point(self, idx):
        """Return dictionary of point values at `idx` for current chain
        with variables names as keys.
        """
-        raise NotImplementedError
+        raise NotImplementedError()

    @property
    def stat_names(self):
@@ -258,6 +265,11 @@ def __init__(self, straces):
                raise ValueError("Chains are not unique.")
            self._straces[strace.chain] = strace

+        self._report = SamplerReport()
+        for strace in straces:
+            if hasattr(strace, '_warnings'):
+                self._report._add_warnings(strace._warnings, strace.chain)
+
    def __repr__(self):
        template = '<{}: {} chains, {} iterations, {} variables>'
        return template.format(self.__class__.__name__,
@@ -271,6 +283,10 @@ def nchains(self):
    def chains(self):
        return list(sorted(self._straces.keys()))

+    @property
+    def report(self):
+        return self._report
+
    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self._slice(idx)
@@ -303,7 +319,7 @@ def __getitem__(self, idx):
            raise KeyError("Unknown variable %s" % var)

    _attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
-                  'supports_sampler_stats'])
+                  'supports_sampler_stats', '_report'])

    def __getattr__(self, name):
        # Avoid infinite recursion when called before __init__
@@ -447,10 +463,13 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
                               for chain in chains]
        return _squeeze_cat(results, combine, squeeze)

-    def _slice(self, idx):
+    def _slice(self, slice):
        """Return a new MultiTrace object sliced according to `idx`."""

Member: according to idx --> according to slice

Member (author): fixed

-        new_traces = [trace._slice(idx) for trace in self._straces.values()]
-        return MultiTrace(new_traces)
+        new_traces = [trace._slice(slice) for trace in self._straces.values()]
+        trace = MultiTrace(new_traces)
+        idxs = slice.indices(len(self))
+        trace._report = self._report._slice(*idxs)
+        return trace

    def point(self, idx, chain=None):
        """Return a dictionary of point values at `idx`.
@@ -502,6 +521,7 @@ def merge_traces(mtraces):
        if new_chain in base_mtrace._straces:
            raise ValueError("Chains are not unique.")
        base_mtrace._straces[new_chain] = strace
+    base_mtrace._report = merge_reports([trace.report for trace in mtraces])
    return base_mtrace


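Taken together, the base.py changes mean every MultiTrace now exposes a SamplerReport. A minimal usage sketch (the model, draw count, and chains argument are illustrative, not part of this diff):

    import pymc3 as pm

    with pm.Model():
        mu = pm.Normal('mu', 0., 1.)
        trace = pm.sample(1000, chains=2)

    report = trace.report    # SamplerReport assembled from each chain's _warnings
    if not report.ok:        # ok is True only if no warning reaches 'warn' level
        report.raise_ok()    # raises ValueError for warnings at 'error' level or above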
5 changes: 3 additions & 2 deletions pymc3/backends/ndarray.py
@@ -116,8 +116,9 @@ def close(self):
        self.samples = {var: vtrace[:self.draw_idx]
                        for var, vtrace in self.samples.items()}
        if self._stats is not None:
-            self._stats = [{var: trace[:self.draw_idx] for var, trace in stats.items()}
-                           for stats in self._stats]
+            self._stats = [
+                {var: trace[:self.draw_idx] for var, trace in stats.items()}
+                for stats in self._stats]

    # Selection methods

145 changes: 145 additions & 0 deletions pymc3/backends/report.py
@@ -0,0 +1,145 @@
from collections import namedtuple
import logging


logger = logging.getLogger('pymc3')


SamplerWarning = namedtuple(
    'SamplerWarning',
    "kind, message, level, step, exec_info, extra")

Member: might make sense to have SamplerWarning.kind be an Enum, so that there is a little more reuse?

Member (author): yes, fixed


_LEVELS = {
    'info': logging.INFO,
    'error': logging.ERROR,
    'warn': logging.WARN,
    'debug': logging.DEBUG,
    'critical': logging.CRITICAL,
}


class SamplerReport(object):
    def __init__(self):
        self._chain_warnings = {}
        self._global_warnings = []
        self._effective_n = None
        self._gelman_rubin = None

    @property
    def _warnings(self):
        chains = sum(self._chain_warnings.values(), [])
        return chains + self._global_warnings

    @property
    def ok(self):
        """Whether the automatic convergence checks found no serious problems."""
        return all(_LEVELS[warn.level] < _LEVELS['warn']
                   for warn in self._warnings)

    def raise_ok(self, level='error'):
        errors = [warn for warn in self._warnings
                  if _LEVELS[warn.level] >= _LEVELS[level]]
        if errors:
            raise ValueError('Serious convergence issues during sampling.')

    def _run_convergence_checks(self, trace):
        if trace.nchains == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning('bad-params', msg, 'info', None, None, None)
            self._add_warnings([warn])
            return

        from pymc3 import diagnostics

        self._effective_n = effective_n = diagnostics.effective_n(trace)
        self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace)

        warnings = []
        rhat_max = max(val.max() for val in gelman_rubin.values())
        if rhat_max > 1.4:
            msg = ("The gelman-rubin statistic is larger than 1.4 for some "
                   "parameters. The sampler did not converge.")
            warn = SamplerWarning(
                'convergence', msg, 'error', None, None, gelman_rubin)
            warnings.append(warn)
        elif rhat_max > 1.2:
            msg = ("The gelman-rubin statistic is larger than 1.2 for some "
                   "parameters.")
            warn = SamplerWarning(
                'convergence', msg, 'warn', None, None, gelman_rubin)
            warnings.append(warn)
        elif rhat_max > 1.05:
            msg = ("The gelman-rubin statistic is larger than 1.05 for some "
                   "parameters. This indicates slight problems during "
                   "sampling.")
            warn = SamplerWarning(
                'convergence', msg, 'info', None, None, gelman_rubin)
            warnings.append(warn)

        eff_min = min(val.min() for val in effective_n.values())

Member: would it be more informative to raise a warning if the effective sample size is lower than a certain percentage of the total number of samples? For example, warn if the effective sample size is below 25% of the total.

Member (author): Good idea. We could have both, too. One with level info if it is <25% and one with level warn if it is <200?

Member (@junpenglao, Jan 8, 2018): I would do it the other way around: warn if it is <25% (high autocorrelation and poor mixing are likely an indication of a modelling problem) and info if it is <200 (sometimes people sample only 200 draws for demo purposes, and it would be annoying to see a warning).

Member (author): I guess that comes down to how much we trust effective_n. If we don't trust the numbers it returns, but only use it as an indication that the sampler doesn't work well, then I think I agree. But if we take the values from effective_n at face value, then a low percentage of effective samples isn't in itself a problem (i.e. it should only be info), while a very low absolute number of effective samples is. Say your samples have some autocorrelation, but you don't get any divergences, and gelman_rubin looks fine. Then I don't see a problem with just running the sampler for a very long time, until you get a lot of effective samples. This seems to me like a valid use case, and we'd print warnings if people did this, even though what they are doing is ok (at least I think it is, right?). On the other hand, few effective samples are always a problem if you plan to do anything with your trace. Maybe a middle way would be:

  • info warning if neff < 0.25 * draws
  • warn if neff < 200 and draws > 500

That way we don't issue a low-neff warning if users explicitly asked for a low number of samples.

Member: Yeah, that's a good point; I agree with this solution.
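A minimal sketch of the compromise just agreed on (thresholds taken from the thread; note the report would additionally need to know `draws`, which the code below this thread does not yet track):

    # Sketch only: combine a relative and an absolute effective-sample-size check.
    if eff_min < 0.25 * draws:
        warnings.append(SamplerWarning(
            'convergence', 'Effective samples below 25% of draws.',
            'info', None, None, effective_n))
    if eff_min < 200 and draws > 500:
        warnings.append(SamplerWarning(
            'convergence', 'Fewer than 200 effective samples.',
            'warn', None, None, effective_n))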

        if eff_min < 100:
            msg = ("The estimated number of effective samples is smaller than "
                   "100 for some parameters.")
            warn = SamplerWarning(
                'convergence', msg, 'error', None, None, effective_n)
            warnings.append(warn)
        elif eff_min < 300:
            msg = ("The estimated number of effective samples is smaller than "
                   "300 for some parameters.")
            warn = SamplerWarning(
                'convergence', msg, 'warn', None, None, effective_n)
            warnings.append(warn)

        self._add_warnings(warnings)

    def _add_warnings(self, warnings, chain=None):
        if chain is None:
            warn_list = self._global_warnings
        else:
            warn_list = self._chain_warnings.setdefault(chain, [])
        warn_list.extend(warnings)

    def _log_summary(self):

        def log_warning(warn):
            level = _LEVELS[warn.level]
            logger.log(level, warn.message)

        for chain, warns in self._chain_warnings.items():
            for warn in warns:
                log_warning(warn)
        for warn in self._global_warnings:
            log_warning(warn)

    def _slice(self, start, stop, step):
        report = SamplerReport()

        def filter_warns(warnings):
            filtered = []
            for warn in warnings:
                if warn.step is None:
                    filtered.append(warn)
                elif (start <= warn.step < stop and
                        (warn.step - start) % step == 0):
                    warn = warn._replace(step=warn.step - start)
                    filtered.append(warn)
            return filtered

        report._add_warnings(filter_warns(self._global_warnings))
        for chain in self._chain_warnings:
            report._add_warnings(
                filter_warns(self._chain_warnings[chain]),
                chain)

        return report


def merge_reports(reports):
    report = SamplerReport()
    for rep in reports:
        report._add_warnings(rep._global_warnings)
        for chain in rep._chain_warnings:
            report._add_warnings(rep._chain_warnings[chain], chain)
    return report
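To make the step re-indexing in `_slice` concrete, a small sketch (the 'divergence' kind and the step value are invented for illustration):

    report = SamplerReport()
    report._add_warnings(
        [SamplerWarning('divergence', 'divergence at step 50',
                        'debug', 50, None, None)],
        chain=0)

    # Slicing draws 40..99 keeps the warning and re-indexes its step: 50 - 40 = 10.
    sliced = report._slice(40, 100, 1)
    print(sliced._chain_warnings[0][0].step)   # -> 10

    # merge_reports pools per-chain and global warnings from several reports.
    merged = merge_reports([report, sliced])
    print(len(merged._chain_warnings[0]))      # -> 2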
4 changes: 2 additions & 2 deletions pymc3/diagnostics.py
@@ -250,8 +250,8 @@ def get_neff(x, Vhat):
        if t % 2:
            t -= 1

-        return min(num_chains * num_samples,
-                   int(num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())))
+        neff = num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())
+        return min(num_chains * num_samples, np.floor(neff))

    def generate_neff(trace_values):
        x = np.array(trace_values)
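A worked example of the changed return (numbers invented): with 2 chains of 500 samples and autocorrelation terms summing to 1.5, neff = 2 * 500 / (1 + 2 * 1.5) = 250, and np.floor keeps the value as an integral float rather than truncating it through int() inside the min().

    import numpy as np

    num_chains, num_samples = 2, 500
    rho_sum = 1.5    # illustrative sum of autocorrelation terms
    neff = num_chains * num_samples / (1. + 2 * rho_sum)
    print(min(num_chains * num_samples, np.floor(neff)))   # 250.0, capped at 1000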
14 changes: 7 additions & 7 deletions pymc3/model.py
@@ -10,7 +10,7 @@
from theano import theano, tensor as tt
from theano.tensor.var import TensorVariable

-from pymc3.theanof import set_theano_conf
+from pymc3.theanof import set_theano_conf, floatX
import pymc3 as pm
from pymc3.math import flatten_list
from .memoize import memoize, WithMemoization
@@ -1061,13 +1061,13 @@ def _get_scaling(total_size, shape, ndim):
    scalar
    """
    if total_size is None:
-        coef = pm.floatX(1)
+        coef = floatX(1)
    elif isinstance(total_size, int):
        if ndim >= 1:
            denom = shape[0]
        else:
            denom = 1
-        coef = pm.floatX(total_size) / pm.floatX(denom)
+        coef = floatX(total_size) / floatX(denom)
    elif isinstance(total_size, (list, tuple)):
        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
            raise TypeError('Unrecognized `total_size` type, expected '
@@ -1085,20 +1085,20 @@ def _get_scaling(total_size, shape, ndim):
            raise ValueError('Length of `total_size` is too big, '
                             'number of scalings is bigger that ndim, got %r' % total_size)
        elif (len(begin) + len(end)) == 0:
-            return pm.floatX(1)
+            return floatX(1)
        if len(end) > 0:
            shp_end = shape[-len(end):]
        else:
            shp_end = np.asarray([])
        shp_begin = shape[:len(begin)]
-        begin_coef = [pm.floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
-        end_coef = [pm.floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
+        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
+        end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
        coefs = begin_coef + end_coef
        coef = tt.prod(coefs)
    else:
        raise TypeError('Unrecognized `total_size` type, expected '
                        'int or list of ints, got %r' % total_size)
-    return tt.as_tensor(pm.floatX(coef))
+    return tt.as_tensor(floatX(coef))


class FreeRV(Factor, TensorVariable):
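For context on what `_get_scaling` computes in the simple int branch, a hedged stand-in (np.float64 substitutes for floatX, whose precision actually follows theano's floatX setting):

    import numpy as np

    def scaling_coef(total_size, shape, ndim):
        # Mirrors: coef = floatX(total_size) / floatX(denom)
        denom = shape[0] if ndim >= 1 else 1
        return np.float64(total_size) / np.float64(denom)

    # A minibatch of 100 points standing in for 1000: log-likelihood scaled by 10.
    print(scaling_coef(1000, (100,), 1))   # -> 10.0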