Skip to content

Commit c5f01a0

Browse files
Remove Dirichlet distribution type restrictions
Closes pymc-devs#3999.
1 parent b2c682e commit c5f01a0

File tree

1 file changed

+1
-16
lines changed

1 file changed

+1
-16
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -486,23 +486,9 @@ class Dirichlet(Continuous):
486486

487487
def __init__(self, a, transform=transforms.stick_breaking,
488488
*args, **kwargs):
489-
490-
if not isinstance(a, pm.model.TensorVariable):
491-
if not isinstance(a, list) and not isinstance(a, np.ndarray):
492-
raise TypeError(
493-
'The vector of concentration parameters (a) must be a python list '
494-
'or numpy array.')
495-
a = np.array(a)
496-
if (a <= 0).any():
497-
raise ValueError("All concentration parameters (a) must be > 0.")
498-
499-
shape = np.atleast_1d(a.shape)[-1]
500-
501-
kwargs.setdefault("shape", shape)
502489
super().__init__(transform=transform, *args, **kwargs)
503490

504491
self.size_prefix = tuple(self.shape[:-1])
505-
self.k = tt.as_tensor_variable(shape)
506492
self.a = a = tt.as_tensor_variable(a)
507493
self.mean = a / tt.sum(a)
508494

@@ -569,14 +555,13 @@ def logp(self, value):
569555
-------
570556
TensorVariable
571557
"""
572-
k = self.k
573558
a = self.a
574559

575560
# only defined for sum(value) == 1
576561
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
577562
+ gammaln(tt.sum(a, axis=-1)),
578563
tt.all(value >= 0), tt.all(value <= 1),
579-
k > 1, tt.all(a > 0),
564+
np.logical_not(a.broadcastable), tt.all(a > 0),
580565
broadcast_conditions=False)
581566

582567
def _repr_latex_(self, name=None, dist=None):

0 commit comments

Comments
 (0)