diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index be0df56541..ebdaf3c3e1 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -89,26 +89,48 @@ def log_jac_det(self, value, *inputs): class Ordered(Transform): + """ + Transforms a vector of values into a vector of ordered values. + + Parameters + ---------- + positive: If True, all values are positive. This has better geometry than just chaining with a log transform. + ascending: If True, the values are in ascending order (default). If False, the values are in descending order. + """ + name = "ordered" - def __init__(self, ndim_supp=None): + def __init__(self, ndim_supp=None, positive=False, ascending=True): if ndim_supp is not None: warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning) + self.positive = positive + self.ascending = ascending def backward(self, value, *inputs): - x = pt.zeros(value.shape) - x = pt.set_subtensor(x[..., 0], value[..., 0]) - x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:])) - return pt.cumsum(x, axis=-1) + if self.positive: # Transform both initial value and deltas to be positive + x = pt.exp(value) + else: # Transform only deltas to be positive + x = pt.empty(value.shape) + x = pt.set_subtensor(x[..., 0], value[..., 0]) + x = pt.set_subtensor(x[..., 1:], pt.exp(value[..., 1:])) + x = pt.cumsum(x, axis=-1) # Add deltas cumulatively to initial value + if not self.ascending: + x = x[..., ::-1] + return x def forward(self, value, *inputs): - y = pt.zeros(value.shape) - y = pt.set_subtensor(y[..., 0], value[..., 0]) + if not self.ascending: + value = value[..., ::-1] + y = pt.empty(value.shape) + y = pt.set_subtensor(y[..., 0], pt.log(value[..., 0]) if self.positive else value[..., 0]) y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1])) return y def log_jac_det(self, value, *inputs): - return pt.sum(value[..., 1:], axis=-1) + if self.positive: + return pt.sum(value, axis=-1) + else: + return pt.sum(value[..., 1:], axis=-1) class SumTo1(Transform): diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 12d9b438b5..e28052bab9 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -103,7 +103,7 @@ def check_jacobian_det( x = make_comparable(x) if not elemwise: - jac = pt.log(pt.nlinalg.det(jacobian(x, [y]))) + jac = pt.log(pt.abs(pt.nlinalg.det(jacobian(x, [y])))) else: jac = pt.log(pt.abs(pt.diag(jacobian(x, [y])))) @@ -115,7 +115,7 @@ def check_jacobian_det( ) for yval in domain.vals: - assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol) + assert_allclose(actual_ljd(yval), computed_ljd(yval), rtol=tol, atol=tol) def test_simplex(): @@ -281,6 +281,31 @@ def test_ordered(): vals = get_values(tr.ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3))) assert_array_equal(np.diff(vals) >= 0, True) + # Check that positive=True creates positive and still ordered values + vals = get_values(tr.Ordered(positive=True), Vector(R, 3), pt.vector, floatX(np.zeros(3))) + assert_array_equal(vals > 0, True) + assert_array_equal(np.diff(vals) >= 0, True) + + # Check that positive=True and ascending=False creates descending values + vals = get_values( + tr.Ordered(positive=True, ascending=False), Vector(R, 3), pt.vector, floatX(np.zeros(3)) + ) + assert_array_equal(vals > 0, True) + assert_array_equal(np.diff(vals) <= 0, True) + + # Check that forward and backward are still inverses + ord, vals = tr.Ordered(positive=True, ascending=False), np.array([0.3, 0.2, 0.1]) + assert_allclose(vals, ord.backward(ord.forward(vals)).eval()) + + # Check the jacobian for positive=True and ascending=False + check_jacobian_det( + tr.Ordered(positive=True, ascending=False), + Vector(R, 2), + pt.vector, + floatX(np.array([1, 1])), + elemwise=False, + ) + def test_chain_values(): chain_tranf = tr.Chain([tr.logodds, tr.ordered])