Skip to content

Commit 6ca644d

Browse files
committed
Fix failing tests
1 parent 825f98e commit 6ca644d

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc3/distributions/multivariate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def __init__(self, n, p, *args, **kwargs):
598598

599599
p = p / tt.sum(p, axis=-1, keepdims=True)
600600

601-
if len(self.shape) >= 1:
601+
if len(self.shape) > 1:
602602
self.n = tt.shape_padright(n)
603603
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
604604
else:
@@ -607,9 +607,10 @@ def __init__(self, n, p, *args, **kwargs):
607607
self.p = tt.as_tensor_variable(p)
608608

609609
self.mean = self.n * self.p
610-
mode_ind = tt.argmax(self.p, axis=-1, keepdims=True)
611-
mode = tt.zeros_like(self.mean, dtype=self.dtype)
612-
mode = tt.inc_subtensor(mode[..., mode_ind], 1)
610+
mode = tt.cast(tt.round(self.mean), "int32")
611+
diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
612+
inc_bool_arr = tt.abs_(diff) > 0
613+
mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
613614
self.mode = mode
614615

615616
def _random(self, n, p, size=None, raw_size=None):

0 commit comments

Comments
 (0)