From ee4bd61c871dbd2f90362f1c61889d2a68b0d4c6 Mon Sep 17 00:00:00 2001 From: "Robert P. Goldman" Date: Mon, 28 Oct 2019 20:57:47 -0500 Subject: [PATCH] Add ImputationWarning class. The idea is that a programmer be able to ignore imputation warnings if they know that data is being imputed. It's easier to do this with a distinct class than with just UserWarning. --- pymc3/exceptions.py | 17 +++++++++++++++-- pymc3/model.py | 3 ++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pymc3/exceptions.py b/pymc3/exceptions.py index b2ff9f0c52..de1bf9cd78 100644 --- a/pymc3/exceptions.py +++ b/pymc3/exceptions.py @@ -1,4 +1,9 @@ -__all__ = ['SamplingError', 'IncorrectArgumentsError', 'TraceDirectoryError'] +__all__ = [ + "SamplingError", + "IncorrectArgumentsError", + "TraceDirectoryError", + "ImputationWarning", +] class SamplingError(RuntimeError): @@ -8,6 +13,14 @@ class SamplingError(RuntimeError): class IncorrectArgumentsError(ValueError): pass + class TraceDirectoryError(ValueError): - '''Error from trying to load a trace from an incorrectly-structured directory,''' + """Error from trying to load a trace from an incorrectly-structured directory,""" + + pass + + +class ImputationWarning(UserWarning): + """Warning that there are missing values that will be imputed.""" + pass diff --git a/pymc3/model.py b/pymc3/model.py index 00423e5514..6ffb8053d8 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -21,6 +21,7 @@ from .vartypes import typefilter, discrete_types, continuous_types, isgenerator from .blocking import DictToArrayBijection, ArrayOrdering from .util import get_transformed_name +from .exceptions import ImputationWarning __all__ = [ 'Model', 'Factor', 'compilef', 'fn', 'fastfn', 'modelcontext', @@ -1341,7 +1342,7 @@ def as_tensor(data, name, model, distribution): impute_message = ('Data in {name} contains missing values and' ' will be automatically imputed from the' ' sampling distribution.'.format(name=name)) - warnings.warn(impute_message, UserWarning) + warnings.warn(impute_message, ImputationWarning) from .distributions import NoDistribution testval = np.broadcast_to(distribution.default(), data.shape)[data.mask] fakedist = NoDistribution.dist(shape=data.mask.sum(), dtype=dtype,