-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Refactor HMC and warning system #2677
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
16a279b
ff4535f
1429d58
b34b5e7
1fa4b8f
3eb6a66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from collections import namedtuple | ||
import logging | ||
|
||
|
||
logger = logging.getLogger('pymc3') | ||
|
||
|
||
# Immutable record describing one diagnostic produced by a sampler.
#   kind      -- short category tag (e.g. 'convergence', 'bad-params')
#   message   -- human readable text that is written to the log
#   level     -- severity name; must be one of the keys of ``_LEVELS``
#   step      -- index of the draw the warning refers to, or None when the
#                warning is not tied to a single draw
#   exec_info -- presumably exception info for failures; TODO confirm
#   extra     -- optional payload, e.g. the diagnostic values that
#                triggered the warning
SamplerWarning = namedtuple(
    'SamplerWarning',
    ['kind', 'message', 'level', 'step', 'exec_info', 'extra'])
|
||
|
||
# Translates the textual ``level`` stored on a SamplerWarning into the
# numeric level understood by the ``logging`` module, so that severities
# can be compared and passed to ``logger.log``.
_LEVELS = dict(
    info=logging.INFO,
    error=logging.ERROR,
    warn=logging.WARNING,
    debug=logging.DEBUG,
    critical=logging.CRITICAL,
)
|
||
|
||
class SamplerReport(object):
    """Collect sampler warnings and convergence diagnostics for a trace.

    Warnings are stored either per chain (``_chain_warnings``, a dict
    mapping a chain identifier to a list of ``SamplerWarning``) or globally
    (``_global_warnings``). ``_effective_n`` and ``_gelman_rubin`` hold the
    results of the last call to ``_run_convergence_checks`` (``None`` until
    the checks ran with more than one chain).
    """

    def __init__(self):
        self._chain_warnings = {}
        self._global_warnings = []
        self._effective_n = None
        self._gelman_rubin = None

    @property
    def _warnings(self):
        """Return a new list: all per-chain warnings, then global ones."""
        # Flatten in a single pass; ``sum(lists, [])`` would re-copy the
        # accumulated list once per chain.
        chains = [warn
                  for warns in self._chain_warnings.values()
                  for warn in warns]
        return chains + self._global_warnings

    @property
    def ok(self):
        """Whether the automatic convergence checks found no serious problems.

        ``True`` if every recorded warning has a level below 'warn'
        (also ``True`` when nothing was recorded), ``False`` otherwise.
        """
        return all(_LEVELS[warn.level] < _LEVELS['warn']
                   for warn in self._warnings)

    def raise_ok(self, level='error'):
        """Raise if any recorded warning is at least as severe as `level`.

        Parameters
        ----------
        level : str
            A key of ``_LEVELS``; warnings at this level or above count as
            failures.

        Raises
        ------
        ValueError
            If at least one warning reaches `level`.
        KeyError
            If `level` (or the level of a stored warning) is not a key of
            ``_LEVELS``.
        """
        threshold = _LEVELS[level]
        if any(_LEVELS[warn.level] >= threshold for warn in self._warnings):
            raise ValueError('Serious convergence issues during sampling.')

    def _run_convergence_checks(self, trace):
        """Compute convergence diagnostics for `trace` and record warnings.

        With a single chain only an informational warning is recorded and
        no diagnostics are computed. Otherwise the effective sample size
        and the Gelman-Rubin statistic are stored on the report and a
        global warning is added for each statistic that crosses one of the
        thresholds below.
        """
        if trace.nchains == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning('bad-params', msg, 'info', None, None, None)
            self._add_warnings([warn])
            return

        # Imported here rather than at module level; presumably to avoid a
        # circular import with the pymc3 package -- TODO confirm.
        from pymc3 import diagnostics

        self._effective_n = effective_n = diagnostics.effective_n(trace)
        self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace)

        warnings = []

        # Only the worst (largest) R-hat over all variables decides the
        # severity; the full dict travels along as ``extra``.
        rhat_max = max(val.max() for val in gelman_rubin.values())
        if rhat_max > 1.4:
            msg = ("The gelman-rubin statistic is larger than 1.4 for some "
                   "parameters. The sampler did not converge.")
            warn = SamplerWarning(
                'convergence', msg, 'error', None, None, gelman_rubin)
            warnings.append(warn)
        elif rhat_max > 1.2:
            msg = ("The gelman-rubin statistic is larger than 1.2 for some "
                   "parameters.")
            warn = SamplerWarning(
                'convergence', msg, 'warn', None, None, gelman_rubin)
            warnings.append(warn)
        elif rhat_max > 1.05:
            msg = ("The gelman-rubin statistic is larger than 1.05 for some "
                   "parameters. This indicates slight problems during "
                   "sampling.")
            warn = SamplerWarning(
                'convergence', msg, 'info', None, None, gelman_rubin)
            warnings.append(warn)

        # Likewise only the smallest effective sample size matters.
        # NOTE(review): reviewers suggested additionally (or instead)
        # comparing the effective sample size against a fraction of the
        # total number of draws (e.g. 25%), so that no low-n_eff warning is
        # issued when a user explicitly asked for few samples. The absolute
        # thresholds below are unchanged.
        eff_min = min(val.min() for val in effective_n.values())
        if eff_min < 100:
            msg = ("The estimated number of effective samples is smaller than "
                   "100 for some parameters.")
            warn = SamplerWarning(
                'convergence', msg, 'error', None, None, effective_n)
            warnings.append(warn)
        elif eff_min < 300:
            msg = ("The estimated number of effective samples is smaller than "
                   "300 for some parameters.")
            warn = SamplerWarning(
                'convergence', msg, 'warn', None, None, effective_n)
            warnings.append(warn)

        self._add_warnings(warnings)

    def _add_warnings(self, warnings, chain=None):
        """Append `warnings` to the global list or to the list of `chain`.

        With ``chain=None`` the warnings are global. Otherwise they are
        stored under `chain`; the entry is created if missing, even when
        `warnings` is empty.
        """
        if chain is None:
            warn_list = self._global_warnings
        else:
            warn_list = self._chain_warnings.setdefault(chain, [])
        warn_list.extend(warnings)

    def _log_summary(self):
        """Write every stored warning to the module logger at its level."""

        def log_warning(warn):
            level = _LEVELS[warn.level]
            logger.log(level, warn.message)

        # Per-chain warnings first, global ones last.
        for warns in self._chain_warnings.values():
            for warn in warns:
                log_warning(warn)
        for warn in self._global_warnings:
            log_warning(warn)

    def _slice(self, start, stop, step):
        """Return a new report restricted to the draws ``start:stop:step``.

        Warnings without a step (``step is None``) are always kept.
        A warning tied to a draw is kept only if that draw is selected by
        the slice, and its ``step`` is rewritten to the position of the
        draw inside the sliced trace. The stored diagnostics
        (``_effective_n``, ``_gelman_rubin``) are not copied.

        Assumes concrete integers with ``step >= 1`` -- presumably the
        caller normalizes the slice (e.g. via ``slice.indices``); verify
        against the caller.
        """
        report = SamplerReport()

        def filter_warns(warnings):
            filtered = []
            for warn in warnings:
                if warn.step is None:
                    filtered.append(warn)
                    continue
                offset = warn.step - start
                if start <= warn.step < stop and offset % step == 0:
                    # The selected draws are start, start + step, ...; the
                    # draw at ``warn.step`` is therefore element number
                    # ``offset // step`` of the sliced trace (plain
                    # ``offset`` is only correct for ``step == 1``).
                    filtered.append(warn._replace(step=offset // step))
            return filtered

        report._add_warnings(filter_warns(self._global_warnings))
        for chain in self._chain_warnings:
            report._add_warnings(
                filter_warns(self._chain_warnings[chain]),
                chain)

        return report
|
||
|
||
def merge_reports(reports):
    """Combine several sampler reports into one new ``SamplerReport``.

    Global warnings of all inputs are concatenated in input order, and
    per-chain warnings are appended to the list of the same chain key.
    Only warnings are merged; the stored convergence diagnostics of the
    inputs are not carried over. The inputs are not modified.
    """
    merged = SamplerReport()
    for source in reports:
        merged._add_warnings(source._global_warnings)
        for chain_id, chain_warns in source._chain_warnings.items():
            merged._add_warnings(chain_warns, chain_id)
    return merged
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
according to `idx` --> according to `slice`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed