Skip to content

Commit 36b1825

Browse files
committed
Fix bug in Bound
1 parent 27fd1b6 commit 36b1825

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

pymc3/distributions/distribution.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -444,17 +444,20 @@ def __init__(self, distribution, lower, upper,
444444
else:
445445
transform = transforms.lowerbound(lower)
446446
default = lower + 1
447+
else:
448+
default = None
447449

448-
# We only change logp and testval for
449-
# discrete distributions
450+
# We don't use transformations for dicrete variables
450451
if issubclass(distribution, Discrete):
451452
transform = None
452-
if default is not None:
453-
default = default.astype(self.dist.default.type())
454453

454+
kwargs['transform'] = transform
455455
self._wrapped = distribution.dist(*args, **kwargs)
456456
self._default = default
457457

458+
if issubclass(distribution, Discrete) and default is not None:
459+
default = default.astype(str(self._wrapped.default().dtype))
460+
458461
if default is None:
459462
defaults = self._wrapped.defaults
460463
for name in defaults:
@@ -470,16 +473,16 @@ def __init__(self, distribution, lower, upper,
470473
transform=self._wrapped.transform)
471474

472475
def _random(self, lower, upper, point=None, size=None):
473-
if lower is None:
474-
lower = -np.inf
475-
if upper is None:
476-
upper = np.inf
477-
478-
samples = np.zeros(size).flatten()
476+
lower = np.asarray(lower)
477+
upper = np.asarray(upper)
478+
if lower.size > 1 or upper.size > 1:
479+
raise ValueError('Drawing samples from distributions with '
480+
'array-valued bounds is not supported.')
481+
samples = np.zeros(size, dtype=self.dtype).flatten()
479482
i, n = 0, len(samples)
480483
while i < len(samples):
481484
sample = self._wrapped.random(point=point, size=n)
482-
select = sample[np.logical_and(sample > lower, sample <= upper)]
485+
select = sample[np.logical_and(sample >= lower, sample <= upper)]
483486
samples[i:(i + len(select))] = select[:]
484487
i += len(select)
485488
n -= len(select)
@@ -489,18 +492,31 @@ def _random(self, lower, upper, point=None, size=None):
489492
return samples
490493

491494
def random(self, point=None, size=None, repeat=None):
492-
lower, upper = draw_values([self.lower, self.upper], point=point)
493-
return generate_samples(self._random, lower, upper, point,
494-
dist_shape=self.shape,
495-
size=size)
495+
if self.lower is None and self.upper is None:
496+
return self._wrapped.random(point=point, size=size)
497+
elif self.lower is not None and self.upper is not None:
498+
lower, upper = draw_values([self.lower, self.upper], point=point)
499+
return generate_samples(self._random, lower, upper, point,
500+
dist_shape=self.shape,
501+
size=size)
502+
elif self.lower is not None:
503+
lower = draw_values([self.lower], point=point)
504+
return generate_samples(self._random, lower, np.inf, point,
505+
dist_shape=self.shape,
506+
size=size)
507+
else:
508+
upper = draw_values([self.upper], point=point)
509+
return generate_samples(self._random, -np.inf, upper, point,
510+
dist_shape=self.shape,
511+
size=size)
496512

497513
def logp(self, value):
498514
logp = self._wrapped.logp(value)
499515
bounds = []
500516
if self.lower is not None:
501-
bounds.append(value > self.lower)
517+
bounds.append(value >= self.lower)
502518
if self.upper is not None:
503-
bounds.append(value < self.upper)
519+
bounds.append(value <= self.upper)
504520
if len(bounds) > 0:
505521
return bound(logp, *bounds)
506522
else:
@@ -516,13 +532,16 @@ class Bound(object):
516532
truncated distributions, use `Bound` in combination with
517533
a `pm.Potential` with the cumulative probability function.
518534
535+
The bounds are inclusive for discrete distributions.
536+
519537
Parameters
520538
----------
521539
distribution : pymc3 distribution
522-
Distribution to be transformed into a bounded distribution
540+
Distribution to be transformed into a bounded distribution.
523541
lower : float or array like, optional
524-
Lower bound of the distribution
542+
Lower bound of the distribution.
525543
upper : float or array like, optional
544+
Upper bound of the distribution.
526545
527546
Example
528547
-------

pymc3/distributions/transforms.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
__all__ = ['transform', 'stick_breaking', 'logodds', 'interval',
11-
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
11+
'lowerbound', 'upperbound', 'log', 'sum_to_1', 't_stick_breaking']
1212

1313

1414
class Transform(object):
@@ -33,6 +33,7 @@ def jacobian_det(self, x):
3333
raise NotImplementedError
3434

3535
def apply(self, dist):
36+
# avoid circular import
3637
return TransformedDistribution.dist(dist, self)
3738

3839
def __str__(self):
@@ -90,7 +91,7 @@ def backward(self, x):
9091

9192
def forward(self, x):
9293
return tt.log(x)
93-
94+
9495
def forward_val(self, x, point=None):
9596
return self.forward(x)
9697

@@ -111,7 +112,7 @@ def backward(self, x):
111112

112113
def forward(self, x):
113114
return logit(x)
114-
115+
115116
def forward_val(self, x, point=None):
116117
return self.forward(x)
117118

@@ -324,6 +325,6 @@ def forward(self, y):
324325

325326
def forward_val(self, x, point=None):
326327
return self.forward(x)
327-
328+
328329
def jacobian_det(self, y):
329330
return tt.sum(y[self.diag_idxs])

pymc3/tests/test_distributions.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -855,26 +855,44 @@ def ref_pdf(value):
855855

856856

857857
def test_bound():
858+
np.random.seed(42)
858859
UnboundNormal = Bound(Normal)
859860
dist = UnboundNormal.dist(mu=0, sd=1)
860861
assert dist.transform is None
861862
assert dist.default() == 0.
863+
assert isinstance(dist.random(), np.ndarray)
864+
862865
LowerNormal = Bound(Normal, lower=1)
863866
dist = LowerNormal.dist(mu=0, sd=1)
864867
assert dist.logp(0).eval() == -np.inf
865868
assert dist.default() > 1
869+
assert dist.transform is not None
870+
assert np.all(dist.random() > 1)
871+
866872
UpperNormal = Bound(Normal, upper=-1)
867873
dist = UpperNormal.dist(mu=0, sd=1)
868-
assert dist.logp(-1).eval() == -np.inf
874+
assert dist.logp(-0.5).eval() == -np.inf
869875
assert dist.default() < -1
876+
assert dist.transform is not None
877+
assert np.all(dist.random() < -1)
878+
870879
ArrayNormal = Bound(Normal, lower=[1, 2], upper=[2, 3])
871-
dist = ArrayNormal.dist(mu=0, sd=1)
872-
assert_equal(dist.logp([1, 2]).eval(), -np.array([np.inf, np.inf]))
880+
dist = ArrayNormal.dist(mu=0, sd=1, shape=2)
881+
assert_equal(dist.logp([0.5, 3.5]).eval(), -np.array([np.inf, np.inf]))
873882
assert_equal(dist.default(), np.array([1.5, 2.5]))
883+
assert dist.transform is not None
884+
with pytest.raises(ValueError) as err:
885+
dist.random()
886+
err.match('Drawing samples from distributions with array-valued')
887+
874888
with Model():
875889
a = ArrayNormal('c', shape=2)
876890
assert_equal(a.tag.test_value, np.array([1.5, 2.5]))
877891

892+
rand = Bound(Binomial, lower=10).dist(n=20, p=0.3).random()
893+
assert rand.dtype in [np.int16, np.int32, np.int64]
894+
assert rand >= 10
895+
878896

879897
def test_repr_latex_():
880898
with Model():

0 commit comments

Comments
 (0)