Skip to content

Commit d01aaf0

Browse files
Max HornJunpeng Lao
authored andcommitted
Converted eps value to floatX fixes test failures of #2515 and lowered precision requirements for test_transforms on float32 machines (#2517)
1 parent b91554a commit d01aaf0

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pymc3/distributions/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ class StickBreaking(Transform):
250250

251251
name = "stickbreaking"
252252

253-
def __init__(self, eps=np.finfo(theano.config.floatX).eps):
253+
def __init__(self, eps=floatX(np.finfo(theano.config.floatX).eps)):
254254
self.eps = eps
255255

256256
def forward(self, x_):
@@ -263,7 +263,7 @@ def forward(self, x_):
263263
k = tt.arange(Km1)[(slice(None), ) + (None, ) * (x.ndim - 1)]
264264
eq_share = logit(1. / (Km1 + 1 - k).astype(str(x_.dtype)))
265265
y = logit(z) - eq_share
266-
return y.T
266+
return floatX(y.T)
267267

268268
def forward_val(self, x, point=None):
269269
return self.forward(x)
@@ -278,7 +278,7 @@ def backward(self, y_):
278278
yu = tt.concatenate([tt.ones(y[:1].shape), 1 - z])
279279
S = tt.extra_ops.cumprod(yu, 0)
280280
x = S * yl
281-
return x.T
281+
return floatX(x.T)
282282

283283
def jacobian_det(self, y_):
284284
y = y_.T

pymc3/tests/test_transforms.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from .checks import close_to, close_to_logical
88
from ..theanof import jacobian
99

10-
tol = 1e-7
10+
11+
# some transforms (stick breaking) require additon of small slack in order to be numerically
12+
# stable. The minimal addable slack for float32 is higher thus we need to be less strict
13+
tol = 1e-7 if theano.config.floatX == 'flaot64' else 1e-6
1114

1215

1316
def check_transform_identity(transform, domain, constructor=tt.dscalar, test=0):

0 commit comments

Comments
 (0)