Skip to content

Commit e6335e1

Browse files
committed
Fix Dirichlet.logp by checking number of categories > 1 only at event dims
1 parent b6660f9 commit e6335e1

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def logp(self, value):
522522
tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(tt.sum(a, axis=-1)),
523523
tt.all(value >= 0),
524524
tt.all(value <= 1),
525-
np.logical_not(a.broadcastable),
525+
np.logical_not(a.broadcastable[-1]),
526526
tt.all(a > 0),
527527
broadcast_conditions=False,
528528
)

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,11 @@ def test_lkj(self, x, eta, n, lp):
16961696
def test_dirichlet(self, n):
16971697
self.pymc3_matches_scipy(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
16981698

1699+
def test_dirichlet_with_unit_batch_shape(self):
1700+
with pm.Model() as model:
1701+
a = pm.Dirichlet("a", a=np.ones((1, 2)))
1702+
np.isfinite(model.check_test_point()[0])
1703+
16991704
def test_dirichlet_shape(self):
17001705
a = tt.as_tensor_variable(np.r_[1, 2])
17011706
with pytest.warns(DeprecationWarning):

0 commit comments

Comments
 (0)