Skip to content

Commit 839827a

Browse files
committed
Add test for theano variable bounds in pm.Bound
1 parent d63bdd6 commit 839827a

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

pymc3/tests/test_distributions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,10 +889,32 @@ def test_bound():
889889
a = ArrayNormal('c', shape=2)
890890
assert_equal(a.tag.test_value, np.array([1.5, 2.5]))
891891

892+
lower = tt.vector('lower')
893+
lower.tag.test_value = np.array([1, 2]).astype(theano.config.floatX)
894+
upper = 3
895+
ArrayNormal = Bound(Normal, lower=lower, upper=upper)
896+
dist = ArrayNormal.dist(mu=0, sd=1, shape=2)
897+
logp = dist.logp([0.5, 3.5]).eval({lower: lower.tag.test_value})
898+
assert_equal(logp, -np.array([np.inf, np.inf]))
899+
assert_equal(dist.default(), np.array([2, 2.5]))
900+
assert dist.transform is not None
901+
902+
with Model():
903+
a = ArrayNormal('c', shape=2)
904+
assert_equal(a.tag.test_value, np.array([2, 2.5]))
905+
892906
rand = Bound(Binomial, lower=10).dist(n=20, p=0.3).random()
893907
assert rand.dtype in [np.int16, np.int32, np.int64]
894908
assert rand >= 10
895909

910+
rand = Bound(Binomial, upper=10).dist(n=20, p=0.8).random()
911+
assert rand.dtype in [np.int16, np.int32, np.int64]
912+
assert rand <= 10
913+
914+
rand = Bound(Binomial, lower=5, upper=8).dist(n=10, p=0.6).random()
915+
assert rand.dtype in [np.int16, np.int32, np.int64]
916+
assert rand >= 5 and rand <= 8
917+
896918

897919
def test_repr_latex_():
898920
with Model():

0 commit comments

Comments
 (0)