Skip to content

Commit 825f98e

Browse files
committed
Allow Multinomial to work with batches of n and p that have more than 2 dimensions
1 parent 4fff38b commit 825f98e

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

pymc3/distributions/multivariate.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -597,12 +597,8 @@ def __init__(self, n, p, *args, **kwargs):
597597
super().__init__(*args, **kwargs)
598598

599599
p = p / tt.sum(p, axis=-1, keepdims=True)
600-
n = np.squeeze(n) # works also if n is a tensor
601600

602-
if len(self.shape) > 1:
603-
self.n = tt.shape_padright(n)
604-
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
605-
elif n.ndim == 1:
601+
if len(self.shape) >= 1:
606602
self.n = tt.shape_padright(n)
607603
self.p = p if p.ndim > 1 else tt.shape_padleft(p)
608604
else:
@@ -611,10 +607,9 @@ def __init__(self, n, p, *args, **kwargs):
611607
self.p = tt.as_tensor_variable(p)
612608

613609
self.mean = self.n * self.p
614-
mode = tt.cast(tt.round(self.mean), "int32")
615-
diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
616-
inc_bool_arr = tt.abs_(diff) > 0
617-
mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
610+
mode_ind = tt.argmax(self.p, axis=-1, keepdims=True)
611+
mode = tt.zeros_like(self.mean, dtype=self.dtype)
612+
mode = tt.inc_subtensor(mode[..., mode_ind], 1)
618613
self.mode = mode
619614

620615
def _random(self, n, p, size=None, raw_size=None):

pymc3/tests/test_distributions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,28 @@ def test_multinomial_vec_2d_p(self):
14471447
decimal=4,
14481448
)
14491449

1450+
def test_batch_multinomial(self):
1451+
n = 10
1452+
vals = np.zeros((4, 5, 3))
1453+
p = np.zeros_like(vals)
1454+
inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
1455+
np.put_along_axis(vals, inds, n, axis=-1)
1456+
np.put_along_axis(p, inds, 1, axis=-1)
1457+
1458+
dist = Multinomial.dist(n=n, p=p, shape=vals.shape)
1459+
value = tt.tensor3()
1460+
value.tag.test_value = np.zeros_like(vals)
1461+
logp = tt.exp(dist.logp(value))
1462+
f = theano.function(inputs=[value], outputs=logp)
1463+
assert_almost_equal(
1464+
f(vals),
1465+
np.ones(vals.shape[:-1] + (1,)),
1466+
decimal=select_by_precision(float64=6, float32=3),
1467+
)
1468+
1469+
sample = dist.random(size=2)
1470+
assert_allclose(sample, np.stack([vals, vals], axis=0))
1471+
14501472
def test_categorical_bounds(self):
14511473
with Model():
14521474
x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))

0 commit comments

Comments
 (0)