Skip to content

Commit cc55279

Browse files
authored
Merge pull request #3670 from rpgoldman/impute-warning
Add ImputationWarning class.
2 parents 15eb75e + ee4bd61 commit cc55279

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

pymc3/exceptions.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
__all__ = ['SamplingError', 'IncorrectArgumentsError', 'TraceDirectoryError']
1+
__all__ = [
2+
"SamplingError",
3+
"IncorrectArgumentsError",
4+
"TraceDirectoryError",
5+
"ImputationWarning",
6+
]
27

38

49
class SamplingError(RuntimeError):
@@ -8,6 +13,14 @@ class SamplingError(RuntimeError):
813
class IncorrectArgumentsError(ValueError):
914
pass
1015

16+
1117
class TraceDirectoryError(ValueError):
12-
'''Error from trying to load a trace from an incorrectly-structured directory,'''
18+
"""Error from trying to load a trace from an incorrectly-structured directory,"""
19+
20+
pass
21+
22+
23+
class ImputationWarning(UserWarning):
24+
"""Warning that there are missing values that will be imputed."""
25+
1326
pass

pymc3/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
2222
from .blocking import DictToArrayBijection, ArrayOrdering
2323
from .util import get_transformed_name
24+
from .exceptions import ImputationWarning
2425

2526
__all__ = [
2627
'Model', 'Factor', 'compilef', 'fn', 'fastfn', 'modelcontext',
@@ -1341,7 +1342,7 @@ def as_tensor(data, name, model, distribution):
13411342
impute_message = ('Data in {name} contains missing values and'
13421343
' will be automatically imputed from the'
13431344
' sampling distribution.'.format(name=name))
1344-
warnings.warn(impute_message, UserWarning)
1345+
warnings.warn(impute_message, ImputationWarning)
13451346
from .distributions import NoDistribution
13461347
testval = np.broadcast_to(distribution.default(), data.shape)[data.mask]
13471348
fakedist = NoDistribution.dist(shape=data.mask.sum(), dtype=dtype,

0 commit comments

Comments
 (0)