Skip to content

Commit c0e7cd5

Browse files
committed
Add flag to disable bounds check.
1 parent 0402aab commit c0e7cd5

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pymc3/distributions/dist_math.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from pymc3.distributions.shape_utils import to_tuple
3535
from pymc3.distributions.special import gammaln
36+
from pymc3.model import modelcontext
3637
from pymc3.theanof import floatX
3738

3839
f = floatX
@@ -67,6 +68,13 @@ def bound(logp, *conditions, **kwargs):
6768
-------
6869
logp with elements set to -inf where any condition is False
6970
"""
71+
72+
# If called inside a model context, see if bounds check is disabled
73+
model = modelcontext(kwargs.get("model"))
74+
if model is not None:
75+
if model.disable_bounds_check:
76+
return logp
77+
7078
broadcast_conditions = kwargs.get("broadcast_conditions", True)
7179

7280
if broadcast_conditions:

pymc3/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,11 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
809809
temporarily in the model context. See the documentation
810810
of theano for a complete list. Set config key
811811
``compute_test_value`` to `raise` if it is None.
812+
disable_bounds_check: bool
813+
Disable checks that ensure that input parameters to distributions
814+
are in a valid range. If your model is built in a way where you
815+
know your parameters can only take on valid values you can disable
816+
this for increased speed.
812817
813818
Examples
814819
--------
@@ -895,11 +900,14 @@ def __new__(cls, *args, **kwargs):
895900
instance._theano_config = theano_config
896901
return instance
897902

898-
def __init__(self, name="", model=None, theano_config=None, coords=None):
903+
def __init__(
904+
self, name="", model=None, theano_config=None, coords=None, disable_bounds_check=False
905+
):
899906
self.name = name
900907
self.coords = {}
901908
self.RV_dims = {}
902909
self.add_coords(coords)
910+
self.disable_bounds_check = disable_bounds_check
903911

904912
if self.parent is not None:
905913
self.named_vars = treedict(parent=self.parent.named_vars)

0 commit comments

Comments
 (0)