From c0e7cd5c8327d0324f5f23d28084a0671095b273 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 23 Dec 2020 17:38:30 +0100 Subject: [PATCH 1/6] Add flag to disable bounds check. --- pymc3/distributions/dist_math.py | 8 ++++++++ pymc3/model.py | 10 +++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index e302000735..4e6d639d1a 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -33,6 +33,7 @@ from pymc3.distributions.shape_utils import to_tuple from pymc3.distributions.special import gammaln +from pymc3.model import modelcontext from pymc3.theanof import floatX f = floatX @@ -67,6 +68,13 @@ def bound(logp, *conditions, **kwargs): ------- logp with elements set to -inf where any condition is False """ + + # If called inside a model context, see if bounds check is disabled + model = modelcontext(kwargs.get("model")) + if model is not None: + if model.disable_bounds_check: + return logp + broadcast_conditions = kwargs.get("broadcast_conditions", True) if broadcast_conditions: diff --git a/pymc3/model.py b/pymc3/model.py index 14d7244278..6a6e159b6c 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -809,6 +809,11 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta): temporarily in the model context. See the documentation of theano for a complete list. Set config key ``compute_test_value`` to `raise` if it is None. + disable_bounds_check: bool + Disable checks that ensure that input parameters to distributions + are in a valid range. If your model is built in a way where you + know your parameters can only take on valid values you can disable + this for increased speed. Examples -------- @@ -895,11 +900,14 @@ def __new__(cls, *args, **kwargs): instance._theano_config = theano_config return instance - def __init__(self, name="", model=None, theano_config=None, coords=None): + def __init__( + self, name="", model=None, theano_config=None, coords=None, disable_bounds_check=False + ): self.name = name self.coords = {} self.RV_dims = {} self.add_coords(coords) + self.disable_bounds_check = disable_bounds_check if self.parent is not None: self.named_vars = treedict(parent=self.parent.named_vars) From 1ef552af8666f6cab925bdd1d385d1d352729f79 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 23 Dec 2020 18:22:48 +0100 Subject: [PATCH 2/6] Move check for model and flag into single line as per @colcarrolls's suggestion. --- pymc3/distributions/dist_math.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index 4e6d639d1a..5c50538327 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -71,9 +71,8 @@ def bound(logp, *conditions, **kwargs): # If called inside a model context, see if bounds check is disabled model = modelcontext(kwargs.get("model")) - if model is not None: - if model.disable_bounds_check: - return logp + if model is not None and model.disable_bounds_check: + return logp broadcast_conditions = kwargs.get("broadcast_conditions", True) From 625cde9b82efa49151057cf172152f54da398398 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 23 Dec 2020 19:04:29 +0100 Subject: [PATCH 3/6] modelcontext raises a TypeError if no model is found. Catch that. --- pymc3/distributions/dist_math.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index 5c50538327..27bc0292ff 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -70,9 +70,12 @@ def bound(logp, *conditions, **kwargs): """ # If called inside a model context, see if bounds check is disabled - model = modelcontext(kwargs.get("model")) - if model is not None and model.disable_bounds_check: - return logp + try: + model = modelcontext(kwargs.get("model")) + if model.disable_bounds_check: + return logp + except TypeError: + pass broadcast_conditions = kwargs.get("broadcast_conditions", True) From 78dcdcc57fe01debe7b6ad51d5df3ee89c011dd7 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 23 Dec 2020 19:06:10 +0100 Subject: [PATCH 4/6] Add comment. --- pymc3/distributions/dist_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index 27bc0292ff..dc67f89652 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -74,7 +74,7 @@ def bound(logp, *conditions, **kwargs): model = modelcontext(kwargs.get("model")) if model.disable_bounds_check: return logp - except TypeError: + except TypeError: # No model found pass broadcast_conditions = kwargs.get("broadcast_conditions", True) From 62db6d1b8221c9458439f0bcbd1af46161dadd84 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 23 Dec 2020 19:48:37 +0100 Subject: [PATCH 5/6] Add mention in release-notes. --- RELEASE-NOTES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 4cbb2a6c61..8f415b0932 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -16,6 +16,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang - Removed `theanof.set_theano_config` because it illegally changed Theano's internal state (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)). ### New Features +- Option to `disable_bounds_check=True` when instantiating `pymc3.Model()` for faster sampling for models that cannot violate boundary constraints (which are most of them; see [#4377](https://github.com/pymc-devs/pymc3/pull/4377)). - `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)). - `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234)) From 8af44f23adacfbd4e17a5774af88cd6b61957b3c Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 23 Dec 2020 19:49:17 +0100 Subject: [PATCH 6/6] Add test for when boundaries are disabled. --- pymc3/tests/test_dist_math.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pymc3/tests/test_dist_math.py b/pymc3/tests/test_dist_math.py index 9aa9dbf16c..359cc8d656 100644 --- a/pymc3/tests/test_dist_math.py +++ b/pymc3/tests/test_dist_math.py @@ -60,6 +60,13 @@ def test_bound(): assert np.prod(bound(logp, cond).eval()) == -np.inf +def test_bound_disabled(): + with pm.Model(disable_bounds_check=True): + logp = tt.ones(3) + cond = np.array([1, 0, 1]) + assert np.all(bound(logp, cond).eval() == logp.eval()) + + def test_alltrue_scalar(): assert alltrue_scalar([]).eval() assert alltrue_scalar([True]).eval()