diff --git a/pymc3/distributions/multivariate.py b/pymc3/distributions/multivariate.py
index 3b7196682d..d3c76d5df3 100755
--- a/pymc3/distributions/multivariate.py
+++ b/pymc3/distributions/multivariate.py
@@ -23,6 +23,7 @@

 from scipy import stats, linalg

+from theano.gof.op import get_test_value
 from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
 from theano.tensor.slinalg import Cholesky
 import pymc3 as pm
@@ -487,22 +488,23 @@ class Dirichlet(Continuous):

     def __init__(self, a, transform=transforms.stick_breaking,
                  *args, **kwargs):
-        if not isinstance(a, pm.model.TensorVariable):
-            if not isinstance(a, list) and not isinstance(a, np.ndarray):
-                raise TypeError(
-                    'The vector of concentration parameters (a) must be a python list '
-                    'or numpy array.')
-            a = np.array(a)
-            if (a <= 0).any():
-                raise ValueError("All concentration parameters (a) must be > 0.")
-
-        shape = np.atleast_1d(a.shape)[-1]
+        if kwargs.get('shape') is None:
+            warnings.warn(
+                (
+                    "Shape not explicitly set. "
+                    "Please, set the value using the `shape` keyword argument. "
+                    "Using the test value to infer the shape."
+                ),
+                DeprecationWarning
+            )
+            try:
+                kwargs['shape'] = get_test_value(tt.shape(a))
+            except AttributeError:
+                pass

-        kwargs.setdefault("shape", shape)
         super().__init__(transform=transform, *args, **kwargs)

         self.size_prefix = tuple(self.shape[:-1])
-        self.k = tt.as_tensor_variable(shape)
         self.a = a = tt.as_tensor_variable(a)
         self.mean = a / tt.sum(a)

@@ -569,14 +571,13 @@ def logp(self, value):
         -------
         TensorVariable
         """
-        k = self.k
         a = self.a

         # only defined for sum(value) == 1
         return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
                      + gammaln(tt.sum(a, axis=-1)),
                      tt.all(value >= 0), tt.all(value <= 1),
-                     k > 1, tt.all(a > 0),
+                     np.logical_not(a.broadcastable), tt.all(a > 0),
                      broadcast_conditions=False)

     def _repr_latex_(self, name=None, dist=None):
diff --git a/pymc3/tests/test_dist_math.py b/pymc3/tests/test_dist_math.py
index f54f91bc2e..65661afd47 100644
--- a/pymc3/tests/test_dist_math.py
+++ b/pymc3/tests/test_dist_math.py
@@ -126,11 +126,11 @@ def test_multinomial_bound():
     n = x.sum()

     with pm.Model() as modelA:
-        p_a = pm.Dirichlet('p', floatX(np.ones(2)))
+        p_a = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
         MultinomialA('x', n, p_a, observed=x)

     with pm.Model() as modelB:
-        p_b = pm.Dirichlet('p', floatX(np.ones(2)))
+        p_b = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
         MultinomialB('x', n, p_b, observed=x)

     assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),
diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index ba59447d77..2897211e9b 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -1328,17 +1328,14 @@ def test_dirichlet(self, n):
             Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf
         )

-    @pytest.mark.parametrize("n", [3, 4])
-    def test_dirichlet_init_fail(self, n):
-        with Model():
-            with pytest.raises(
-                ValueError, match=r"All concentration parameters \(a\) must be > 0."
-            ):
-                _ = Dirichlet("x", a=np.zeros(n), shape=n)
-            with pytest.raises(
-                ValueError, match=r"All concentration parameters \(a\) must be > 0."
-            ):
-                _ = Dirichlet("x", a=np.array([-1.0] * n), shape=n)
+    def test_dirichlet_shape(self):
+        a = tt.as_tensor_variable(np.r_[1, 2])
+        with pytest.warns(DeprecationWarning):
+            dir_rv = Dirichlet.dist(a)
+            assert dir_rv.shape == (2,)
+
+        with pytest.warns(DeprecationWarning), theano.change_flags(compute_test_value="ignore"):
+            dir_rv = Dirichlet.dist(tt.vector())

     def test_dirichlet_2D(self):
         self.pymc3_matches_scipy(
diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py
index 17e7c29140..b4d4338f1e 100644
--- a/pymc3/tests/test_distributions_random.py
+++ b/pymc3/tests/test_distributions_random.py
@@ -912,7 +912,7 @@ def test_mixture_random_shape():
                        nr.poisson(9, size=10)])
     with pm.Model() as m:
         comp0 = pm.Poisson.dist(mu=np.ones(2))
-        w0 = pm.Dirichlet('w0', a=np.ones(2))
+        w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
         like0 = pm.Mixture('like0',
                            w=w0,
                            comp_dists=comp0,
@@ -920,7 +920,7 @@
         comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
                                 shape=(20, 2))
-        w1 = pm.Dirichlet('w1', a=np.ones(2))
+        w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
         like1 = pm.Mixture('like1',
                            w=w1,
                            comp_dists=comp1,
@@ -967,7 +967,7 @@ def test_mixture_random_shape_fast():
                        nr.poisson(9, size=10)])
     with pm.Model() as m:
         comp0 = pm.Poisson.dist(mu=np.ones(2))
-        w0 = pm.Dirichlet('w0', a=np.ones(2))
+        w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
         like0 = pm.Mixture('like0',
                            w=w0,
                            comp_dists=comp0,
@@ -975,7 +975,7 @@
         comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
                                 shape=(20, 2))
-        w1 = pm.Dirichlet('w1', a=np.ones(2))
+        w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
         like1 = pm.Mixture('like1',
                            w=w1,
                            comp_dists=comp1,
diff --git a/pymc3/tests/test_mixture.py b/pymc3/tests/test_mixture.py
index 2547b09dda..308a1aa2e5 100644
--- a/pymc3/tests/test_mixture.py
+++ b/pymc3/tests/test_mixture.py
@@ -79,7 +79,7 @@ def test_dimensions(self):

     def test_mixture_list_of_normals(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
             mu = Normal('mu', 0., 10., shape=self.norm_w.size)
             tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
             Mixture('x_obs', w,
@@ -98,7 +98,7 @@ def test_mixture_list_of_normals(self):

     def test_normal_mixture(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
             mu = Normal('mu', 0., 10., shape=self.norm_w.size)
             tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
             NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
@@ -135,7 +135,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
         with Model() as model0:
             mus = Normal('mus', shape=comp_shape)
             taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
-            ws = Dirichlet('ws', np.ones(ncomp))
+            ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
             mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd,
                                      comp_shape=comp_shape)
             obs0 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
@@ -145,7 +145,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
         with Model() as model1:
             mus = Normal('mus', shape=comp_shape)
             taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
-            ws = Dirichlet('ws', np.ones(ncomp))
+            ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
             comp_dist = [Normal.dist(mu=mus[..., i], tau=taus[..., i],
                                      shape=nd)
                          for i in range(ncomp)]
@@ -163,7 +163,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
             # comp_dists.
             mus = Normal('mus', shape=comp_shape)
             taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
-            ws = Dirichlet('ws', np.ones(ncomp))
+            ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
             if len(nd) > 1:
                 if nd[-1] != ncomp:
                     with pytest.raises(ValueError):
@@ -208,7 +208,7 @@ def test_normal_mixture_nd(self, nd, ncomp):

     def test_poisson_mixture(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
             mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
             Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)
             step = Metropolis()
@@ -224,7 +224,7 @@ def test_poisson_mixture(self):

     def test_mixture_list_of_poissons(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
             mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
             Mixture('x_obs', w,
                     [Poisson.dist(mu[0]), Poisson.dist(mu[1])],
@@ -247,7 +247,7 @@ def test_mixture_of_mvn(self):
         cov2 = np.diag([2.5, 3.5])
         obs = np.asarray([[.5, .5], mu1, mu2])
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones(2)), transform=None)
+            w = Dirichlet('w', floatX(np.ones(2)), transform=None, shape=(2,))
             mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
             mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
             y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
@@ -291,13 +291,13 @@ def test_mixture_of_mixture(self):
                 sigma=1,
                 shape=nbr)
             # weight vector for the mixtures
-            g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
-            l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
+            g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
+            l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
             # mixture components
             g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
             l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
             # mixture of mixtures
-            mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
+            mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None, shape=(2,))
             mix = Mixture('mix', w=mix_w,
                           comp_dists=[g_mix, l_mix],
                           observed=np.exp(self.norm_x))
@@ -378,7 +378,7 @@ def build_toy_dataset(N, K):
         X, y = build_toy_dataset(N, K)

         with pm.Model() as model:
-            pi = pm.Dirichlet('pi', np.ones(K))
+            pi = pm.Dirichlet('pi', np.ones(K), shape=(K,))
             comp_dist = []
             mu = []
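
Usage note (for reviewers, not part of the patch): a minimal sketch of the behaviour this change introduces, assuming a post-patch pymc3 build on the Theano backend. Passing `shape` explicitly is the supported spelling going forward; omitting it still works, but falls back to the test value of `tt.shape(a)` and emits a `DeprecationWarning`, as the diff above shows. The variable names below are illustrative only.

```python
import warnings

import numpy as np
import pymc3 as pm

# Preferred: state the shape of the Dirichlet explicitly.
with pm.Model():
    w = pm.Dirichlet('w', a=np.ones(3), shape=(3,))

# Deprecated: omit `shape` and let the test value of tt.shape(a)
# infer it; this now triggers a DeprecationWarning.
with pm.Model():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        w_implicit = pm.Dirichlet('w_implicit', a=np.ones(3))
    assert any(issubclass(c.category, DeprecationWarning) for c in caught)
```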