Skip to content

Commit 37239fc

Browse files
authored
Add flag to disable bounds check for speed-up (#4377)
* Add flag to disable bounds check. * Move check for model and flag into single line as per @colcarrolls's suggestion. * modelcontext raises a TypeError if no model is found. Catch that. * Add comment. * Add mention in release-notes. * Add test for when boundaries are disabled.
1 parent 0402aab commit 37239fc

File tree

4 files changed

+27
-1
lines changed

4 files changed

+27
-1
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
1616
- Removed `theanof.set_theano_config` because it illegally changed Theano's internal state (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).
1717

1818
### New Features
19+
- Option to `disable_bounds_check=True` when instantiating `pymc3.Model()` for faster sampling for models that cannot violate boundary constraints (which are most of them; see [#4377](https://github.com/pymc-devs/pymc3/pull/4377)).
1920
- `OrderedProbit` distribution added (see [#4232](https://github.com/pymc-devs/pymc3/pull/4232)).
2021
- `plot_posterior_predictive_glm` now works with `arviz.InferenceData` as well (see [#4234](https://github.com/pymc-devs/pymc3/pull/4234))
2122

pymc3/distributions/dist_math.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from pymc3.distributions.shape_utils import to_tuple
3535
from pymc3.distributions.special import gammaln
36+
from pymc3.model import modelcontext
3637
from pymc3.theanof import floatX
3738

3839
f = floatX
@@ -67,6 +68,15 @@ def bound(logp, *conditions, **kwargs):
6768
-------
6869
logp with elements set to -inf where any condition is False
6970
"""
71+
72+
# If called inside a model context, see if bounds check is disabled
73+
try:
74+
model = modelcontext(kwargs.get("model"))
75+
if model.disable_bounds_check:
76+
return logp
77+
except TypeError: # No model found
78+
pass
79+
7080
broadcast_conditions = kwargs.get("broadcast_conditions", True)
7181

7282
if broadcast_conditions:

pymc3/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,11 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
809809
temporarily in the model context. See the documentation
810810
of theano for a complete list. Set config key
811811
``compute_test_value`` to `raise` if it is None.
812+
disable_bounds_check: bool
813+
Disable checks that ensure that input parameters to distributions
814+
are in a valid range. If your model is built in a way where you
815+
know your parameters can only take on valid values you can disable
816+
this for increased speed.
812817
813818
Examples
814819
--------
@@ -895,11 +900,14 @@ def __new__(cls, *args, **kwargs):
895900
instance._theano_config = theano_config
896901
return instance
897902

898-
def __init__(self, name="", model=None, theano_config=None, coords=None):
903+
def __init__(
904+
self, name="", model=None, theano_config=None, coords=None, disable_bounds_check=False
905+
):
899906
self.name = name
900907
self.coords = {}
901908
self.RV_dims = {}
902909
self.add_coords(coords)
910+
self.disable_bounds_check = disable_bounds_check
903911

904912
if self.parent is not None:
905913
self.named_vars = treedict(parent=self.parent.named_vars)

pymc3/tests/test_dist_math.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def test_bound():
6060
assert np.prod(bound(logp, cond).eval()) == -np.inf
6161

6262

63+
def test_bound_disabled():
64+
with pm.Model(disable_bounds_check=True):
65+
logp = tt.ones(3)
66+
cond = np.array([1, 0, 1])
67+
assert np.all(bound(logp, cond).eval() == logp.eval())
68+
69+
6370
def test_alltrue_scalar():
6471
assert alltrue_scalar([]).eval()
6572
assert alltrue_scalar([True]).eval()

0 commit comments

Comments
 (0)