We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 825f98e commit 6ca644dCopy full SHA for 6ca644d
pymc3/distributions/multivariate.py
@@ -598,7 +598,7 @@ def __init__(self, n, p, *args, **kwargs):
598
599
p = p / tt.sum(p, axis=-1, keepdims=True)
600
601
- if len(self.shape) >= 1:
+ if len(self.shape) > 1:
602
self.n = tt.shape_padright(n)
603
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
604
else:
@@ -607,9 +607,10 @@ def __init__(self, n, p, *args, **kwargs):
607
self.p = tt.as_tensor_variable(p)
608
609
self.mean = self.n * self.p
610
- mode_ind = tt.argmax(self.p, axis=-1, keepdims=True)
611
- mode = tt.zeros_like(self.mean, dtype=self.dtype)
612
- mode = tt.inc_subtensor(mode[..., mode_ind], 1)
+ mode = tt.cast(tt.round(self.mean), "int32")
+ diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
+ inc_bool_arr = tt.abs_(diff) > 0
613
+ mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
614
self.mode = mode
615
616
def _random(self, n, p, size=None, raw_size=None):
0 commit comments