diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py index cdc7219b6c..08eec61b10 100644 --- a/pymc3/distributions/transforms.py +++ b/pymc3/distributions/transforms.py @@ -1,3 +1,5 @@ +import warnings + import theano import theano.tensor as tt @@ -14,6 +16,7 @@ __all__ = [ "transform", "stick_breaking", + "stick_breaking2", "logodds", "interval", "log_exp_m1", @@ -510,6 +513,62 @@ def t_stick_breaking(eps): return StickBreaking(eps) +class StickBreaking2(Transform): + """ + Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of real values. + """ + + name = "stickbreaking" + + def __init__(self, eps=None): + if eps is not None: + warnings.warn("The argument `eps` is depricated and will not be used.", + DeprecationWarning) + + def forward(self, x_): + x = x_.T + n = x.shape[0] + lx = tt.log(x) + shift = tt.sum(lx, 0, keepdims=True) / n + y = lx[:-1] - shift + return floatX(y.T) + + def forward_val(self, x_): + x = x_.T + n = x.shape[0] + lx = np.log(x) + shift = np.sum(lx, 0, keepdims=True) / n + y = lx[:-1] - shift + return floatX(y.T) + + def backward(self, y_): + y = y_.T + y = tt.concatenate([y, -tt.sum(y, 0, keepdims=True)]) + # "softmax" with vector support and no deprication warning: + e_y = tt.exp(y - tt.max(y, 0, keepdims=True)) + x = e_y / tt.sum(e_y, 0, keepdims=True) + return floatX(x.T) + + def backward_val(self, y_): + y = y_.T + y = np.concatenate([y, -np.sum(y, 0, keepdims=True)]) + x = np.exp(y)/np.sum(np.exp(y), 0, keepdims=True) + return floatX(x.T) + + def jacobian_det(self, y_): + y = y_.T + Km1 = y.shape[0] + sy = tt.sum(y, 0, keepdims=True) + r = tt.concatenate([y+sy, tt.zeros(sy.shape)]) + # stable according to: http://deeplearning.net/software/theano_versions/0.9.X/NEWS.html + sr = tt.log(tt.sum(tt.exp(r), 0, keepdims=True)) + d = tt.log(Km1) + (Km1*sy) - (Km1*sr) + return tt.sum(d, 0).T + + +stick_breaking2 = StickBreaking2() + + class Circular(ElemwiseTransform): """Transforms a linear space into a circular one. """ diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index 7396e253d7..7267b11fac 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -70,7 +70,7 @@ def check_jacobian_det(transform, domain, computed_ljd(yval), tol) -def test_simplex(): +def test_stickbreaking(): check_vector_transform(tr.stick_breaking, Simplex(2)) check_vector_transform(tr.stick_breaking, Simplex(4)) @@ -78,7 +78,7 @@ def test_simplex(): 3, 2), constructor=tt.dmatrix, test=np.zeros((2, 2))) -def test_simplex_bounds(): +def test_stickbreaking_bounds(): vals = get_values(tr.stick_breaking, Vector(R, 2), tt.dvector, np.array([0, 0])) @@ -90,6 +90,27 @@ def test_simplex_bounds(): R, 2), tt.dvector, np.array([0, 0]), lambda x: x[:-1]) +def test_stickbreaking2(): + check_vector_transform(tr.stick_breaking2, Simplex(2)) + check_vector_transform(tr.stick_breaking2, Simplex(4)) + + check_transform(tr.stick_breaking2, MultiSimplex( + 3, 2), constructor=tt.dmatrix, test=np.zeros((2, 2))) + + +def test_stickbreaking2_bounds(): + vals = get_values(tr.stick_breaking2, Vector(R, 2), + tt.dvector, np.array([0, 0])) + + close_to(vals.sum(axis=1), 1, tol) + close_to_logical(vals > 0, True, tol) + close_to_logical(vals < 1, True, tol) + + check_jacobian_det(tr.stick_breaking2, Vector(R, 2), + tt.dvector, np.array([0, 0]), + lambda x: x[:-1]) + + def test_sum_to_1(): check_vector_transform(tr.sum_to_1, Simplex(2)) check_vector_transform(tr.sum_to_1, Simplex(4))