@@ -486,23 +486,9 @@ class Dirichlet(Continuous):
486
486
487
487
def __init__ (self , a , transform = transforms .stick_breaking ,
488
488
* args , ** kwargs ):
489
-
490
- if not isinstance (a , pm .model .TensorVariable ):
491
- if not isinstance (a , list ) and not isinstance (a , np .ndarray ):
492
- raise TypeError (
493
- 'The vector of concentration parameters (a) must be a python list '
494
- 'or numpy array.' )
495
- a = np .array (a )
496
- if (a <= 0 ).any ():
497
- raise ValueError ("All concentration parameters (a) must be > 0." )
498
-
499
- shape = np .atleast_1d (a .shape )[- 1 ]
500
-
501
- kwargs .setdefault ("shape" , shape )
502
489
super ().__init__ (transform = transform , * args , ** kwargs )
503
490
504
491
self .size_prefix = tuple (self .shape [:- 1 ])
505
- self .k = tt .as_tensor_variable (shape )
506
492
self .a = a = tt .as_tensor_variable (a )
507
493
self .mean = a / tt .sum (a )
508
494
@@ -569,14 +555,13 @@ def logp(self, value):
569
555
-------
570
556
TensorVariable
571
557
"""
572
- k = self .k
573
558
a = self .a
574
559
575
560
# only defined for sum(value) == 1
576
561
return bound (tt .sum (logpow (value , a - 1 ) - gammaln (a ), axis = - 1 )
577
562
+ gammaln (tt .sum (a , axis = - 1 )),
578
563
tt .all (value >= 0 ), tt .all (value <= 1 ),
579
- k > 1 , tt .all (a > 0 ),
564
+ np . logical_not ( a . broadcastable ) , tt .all (a > 0 ),
580
565
broadcast_conditions = False )
581
566
582
567
def _repr_latex_ (self , name = None , dist = None ):
0 commit comments