Skip to content

Commit 27fd1b6

Browse files
committed
Allow arrays in pm.Bound
1 parent fd1e5ba commit 27fd1b6

File tree

2 files changed

+101
-35
lines changed

2 files changed

+101
-35
lines changed

pymc3/distributions/distribution.py

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..vartypes import string_types
99
from .dist_math import bound
1010
# To avoid circular import for transform below
11-
import pymc3 as pm
11+
from pymc3.distributions import transforms
1212

1313
__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Bound',
1414
'Discrete', 'NoDistribution', 'TensorType', 'draw_values']
@@ -416,42 +416,69 @@ class Bounded(Distribution):
416416
See pymc3.distributions.transforms for more information.
417417
"""
418418

419-
def __init__(self, distribution, lower, upper, transform='infer', *args, **kwargs):
420-
self.dist = distribution.dist(*args, **kwargs)
419+
def __init__(self, distribution, lower, upper,
420+
transform='infer', *args, **kwargs):
421+
if lower == -np.inf:
422+
lower = None
423+
if upper == np.inf:
424+
upper = None
421425

422-
self.__dict__.update(self.dist.__dict__)
423-
self.__dict__.update(locals())
426+
if lower is not None:
427+
lower = tt.as_tensor_variable(lower)
428+
if upper is not None:
429+
upper = tt.as_tensor_variable(upper)
424430

425-
if hasattr(self.dist, 'mode'):
426-
self.mode = self.dist.mode
431+
self.lower = lower
432+
self.upper = upper
427433

428434
if transform == 'infer':
435+
if lower is None and upper is None:
436+
transform = None
437+
default = None
438+
elif lower is not None and upper is not None:
439+
transform = transforms.interval(lower, upper)
440+
default = 0.5 * (lower + upper)
441+
elif upper is not None:
442+
transform = transforms.upperbound(upper)
443+
default = upper - 1
444+
else:
445+
transform = transforms.lowerbound(lower)
446+
default = lower + 1
429447

430-
default = self.dist.default()
431-
432-
if not np.isinf(lower) and not np.isinf(upper):
433-
self.transform = pm.distributions.transforms.interval(lower, upper)
434-
if default <= lower or default >= upper:
435-
self.testval = 0.5 * (upper + lower)
448+
# We only change logp and testval for
449+
# discrete distributions
450+
if issubclass(distribution, Discrete):
451+
transform = None
452+
if default is not None:
453+
default = default.astype(self.dist.default.type())
436454

437-
if not np.isinf(lower) and np.isinf(upper):
438-
self.transform = pm.distributions.transforms.lowerbound(lower)
439-
if default <= lower:
440-
self.testval = lower + 1
455+
self._wrapped = distribution.dist(*args, **kwargs)
456+
self._default = default
441457

442-
if np.isinf(lower) and not np.isinf(upper):
443-
self.transform = pm.distributions.transforms.upperbound(upper)
444-
if default >= upper:
445-
self.testval = upper - 1
458+
if default is None:
459+
defaults = self._wrapped.defaults
460+
for name in defaults:
461+
setattr(self, name, getattr(self._wrapped, name))
462+
else:
463+
defaults = ('_default',)
446464

447-
if issubclass(distribution, Discrete):
448-
self.transform = None
465+
super(Bounded, self).__init__(
466+
shape=self._wrapped.shape,
467+
dtype=self._wrapped.dtype,
468+
testval=self._wrapped.testval,
469+
defaults=defaults,
470+
transform=self._wrapped.transform)
449471

450472
def _random(self, lower, upper, point=None, size=None):
473+
if lower is None:
474+
lower = -np.inf
475+
if upper is None:
476+
upper = np.inf
477+
451478
samples = np.zeros(size).flatten()
452479
i, n = 0, len(samples)
453480
while i < len(samples):
454-
sample = self.dist.random(point=point, size=n)
481+
sample = self._wrapped.random(point=point, size=n)
455482
select = sample[np.logical_and(sample > lower, sample <= upper)]
456483
samples[i:(i + len(select))] = select[:]
457484
i += len(select)
@@ -468,21 +495,34 @@ def random(self, point=None, size=None, repeat=None):
468495
size=size)
469496

470497
def logp(self, value):
471-
return bound(self.dist.logp(value),
472-
value >= self.lower, value <= self.upper)
498+
logp = self._wrapped.logp(value)
499+
bounds = []
500+
if self.lower is not None:
501+
bounds.append(value > self.lower)
502+
if self.upper is not None:
503+
bounds.append(value < self.upper)
504+
if len(bounds) > 0:
505+
return bound(logp, *bounds)
506+
else:
507+
return logp
473508

474509

475510
class Bound(object):
476511
R"""
477-
Creates a new upper, lower or upper+lower bounded distribution
512+
Create a new upper, lower or upper+lower bounded distribution.
513+
514+
The resulting distribution is not normalized anymore. This
515+
is usually fine if the bounds are constants. If you need
516+
truncated distributions, use `Bound` in combination with
517+
a `pm.Potential` with the cumulative probability function.
478518
479519
Parameters
480520
----------
481521
distribution : pymc3 distribution
482522
Distribution to be transformed into a bounded distribution
483-
lower : float (optional)
523+
lower : float or array like, optional
484524
Lower bound of the distribution
485-
upper : float (optional)
525+
upper : float or array like, optional
486526
487527
Example
488528
-------
@@ -499,7 +539,7 @@ class Bound(object):
499539
'par3', mu=0.0, sd=1.0, testval=1.0)
500540
"""
501541

502-
def __init__(self, distribution, lower=-np.inf, upper=np.inf):
542+
def __init__(self, distribution, lower=None, upper=None):
503543
self.distribution = distribution
504544
self.lower = lower
505545
self.upper = upper

pymc3/tests/test_distributions.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ..distributions import continuous
1919
from pymc3.theanof import floatX
2020
from numpy import array, inf, log, exp
21-
from numpy.testing import assert_almost_equal
21+
from numpy.testing import assert_almost_equal, assert_equal
2222
import numpy.random as nr
2323
import numpy as np
2424
import pytest
@@ -569,7 +569,8 @@ def test_binomial(self):
569569
self.pymc3_matches_scipy(Binomial, Nat, {'n': NatSmall, 'p': Unit},
570570
lambda value, n, p: sp.binom.logpmf(value, n, p))
571571

572-
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") # Too lazy to propagate decimal parameter through the whole chain of deps
572+
# Too lazy to propagate decimal parameter through the whole chain of deps
573+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
573574
def test_beta_binomial(self):
574575
self.checkd(BetaBinomial, Nat, {'alpha': Rplus, 'beta': Rplus, 'n': NatSmall})
575576

@@ -597,16 +598,19 @@ def test_constantdist(self):
597598
self.pymc3_matches_scipy(Constant, I, {'c': I},
598599
lambda value, c: np.log(c == value))
599600

600-
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") # Too lazy to propagate decimal parameter through the whole chain of deps
601+
# Too lazy to propagate decimal parameter through the whole chain of deps
602+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
601603
def test_zeroinflatedpoisson(self):
602604
self.checkd(ZeroInflatedPoisson, Nat, {'theta': Rplus, 'psi': Unit})
603605

604-
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") # Too lazy to propagate decimal parameter through the whole chain of deps
606+
# Too lazy to propagate decimal parameter through the whole chain of deps
607+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
605608
def test_zeroinflatednegativebinomial(self):
606609
self.checkd(ZeroInflatedNegativeBinomial, Nat,
607610
{'mu': Rplusbig, 'alpha': Rplusbig, 'psi': Unit})
608611

609-
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") # Too lazy to propagate decimal parameter through the whole chain of deps
612+
# Too lazy to propagate decimal parameter through the whole chain of deps
613+
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
610614
def test_zeroinflatedbinomial(self):
611615
self.checkd(ZeroInflatedBinomial, Nat,
612616
{'n': NatSmall, 'p': Unit, 'psi': Unit})
@@ -850,6 +854,28 @@ def ref_pdf(value):
850854
self.pymc3_matches_scipy(TestedInterpolated, R, {}, ref_pdf)
851855

852856

857+
def test_bound():
858+
UnboundNormal = Bound(Normal)
859+
dist = UnboundNormal.dist(mu=0, sd=1)
860+
assert dist.transform is None
861+
assert dist.default() == 0.
862+
LowerNormal = Bound(Normal, lower=1)
863+
dist = LowerNormal.dist(mu=0, sd=1)
864+
assert dist.logp(0).eval() == -np.inf
865+
assert dist.default() > 1
866+
UpperNormal = Bound(Normal, upper=-1)
867+
dist = UpperNormal.dist(mu=0, sd=1)
868+
assert dist.logp(-1).eval() == -np.inf
869+
assert dist.default() < -1
870+
ArrayNormal = Bound(Normal, lower=[1, 2], upper=[2, 3])
871+
dist = ArrayNormal.dist(mu=0, sd=1)
872+
assert_equal(dist.logp([1, 2]).eval(), -np.array([np.inf, np.inf]))
873+
assert_equal(dist.default(), np.array([1.5, 2.5]))
874+
with Model():
875+
a = ArrayNormal('c', shape=2)
876+
assert_equal(a.tag.test_value, np.array([1.5, 2.5]))
877+
878+
853879
def test_repr_latex_():
854880
with Model():
855881
x0 = Binomial('Discrete', p=.5, n=10)

0 commit comments

Comments
 (0)