Commit d8bfe93

Make Multinomial robust against batches (#4169)
* Allow Multinomial to work with batches of n and p that have more than 2 dimensions
* Fix failing tests
* Fix the float32 errors
* Added line to release notes
1 parent b31b42a commit d8bfe93

3 files changed, +23 -4 lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
 - Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)).
 - Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129)
+- Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169)
 
 ### Documentation

pymc3/distributions/multivariate.py

Lines changed: 0 additions & 4 deletions
@@ -597,14 +597,10 @@ def __init__(self, n, p, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         p = p / tt.sum(p, axis=-1, keepdims=True)
-        n = np.squeeze(n)  # works also if n is a tensor
 
         if len(self.shape) > 1:
             self.n = tt.shape_padright(n)
             self.p = p if p.ndim > 1 else tt.shape_padleft(p)
-        elif n.ndim == 1:
-            self.n = tt.shape_padright(n)
-            self.p = p if p.ndim > 1 else tt.shape_padleft(p)
         else:
             # n is a scalar, p is a 1d array
             self.n = tt.as_tensor_variable(n)
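
The shape handling in this hunk can be illustrated with plain NumPy. The sketch below is not PyMC3 code: it uses `n[..., None]` as a stand-in for `tt.shape_padright(n)`, and the batch shapes are hypothetical. It shows one way in which squeezing `n` before padding can misalign it with the batch axes of `p`; whether this exact case motivated the removal of `np.squeeze` is an assumption.

```python
import numpy as np

# Hypothetical batch: n has shape (4, 1), p has shape (4, 1, 3).
n = np.full((4, 1), 10)
p = np.full((4, 1, 3), 1.0 / 3.0)

# Appending a trailing axis (NumPy stand-in for tt.shape_padright)
# lets n broadcast against the category axis of p.
padded = n[..., None]
mean_padded = padded * p

# Squeezing first drops the size-1 batch axis, so the padded n no longer
# lines up with the batch axes of p and broadcasting inflates the result.
squeezed_then_padded = np.squeeze(n)[..., None]
mean_squeezed = squeezed_then_padded * p

print(padded.shape, mean_padded.shape)
print(squeezed_then_padded.shape, mean_squeezed.shape)
```

Without the squeeze, the product keeps the batch shape of `p`; with it, the size-1 axis is lost and the result picks up an extra factor of 4 along the second axis.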

pymc3/tests/test_distributions.py

Lines changed: 22 additions & 0 deletions
@@ -1447,6 +1447,28 @@ def test_multinomial_vec_2d_p(self):
             decimal=4,
         )
 
+    def test_batch_multinomial(self):
+        n = 10
+        vals = np.zeros((4, 5, 3), dtype="int32")
+        p = np.zeros_like(vals, dtype=theano.config.floatX)
+        inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
+        np.put_along_axis(vals, inds, n, axis=-1)
+        np.put_along_axis(p, inds, 1, axis=-1)
+
+        dist = Multinomial.dist(n=n, p=p, shape=vals.shape)
+        value = tt.tensor3(dtype="int32")
+        value.tag.test_value = np.zeros_like(vals, dtype="int32")
+        logp = tt.exp(dist.logp(value))
+        f = theano.function(inputs=[value], outputs=logp)
+        assert_almost_equal(
+            f(vals),
+            np.ones(vals.shape[:-1] + (1,)),
+            decimal=select_by_precision(float64=6, float32=3),
+        )
+
+        sample = dist.random(size=2)
+        assert_allclose(sample, np.stack([vals, vals], axis=0))
+
     def test_categorical_bounds(self):
         with Model():
             x = Categorical("x", p=np.array([0.2, 0.3, 0.5]))
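
The idea behind `test_batch_multinomial` can be reproduced without PyMC3 or Theano. The sketch below builds the same one-hot `p` and matching `vals` with `np.put_along_axis`, then checks that each batch element has probability 1. The helper `multinomial_logp` is a hypothetical pure-Python function written for this illustration, not the library's implementation, and the result here has shape `(4, 5)` rather than the `(4, 5, 1)` the test compares against.

```python
import math

import numpy as np


def multinomial_logp(counts, n, p):
    """Log-probability of one count vector under Multinomial(n, p).

    Hypothetical helper for illustration only.
    """
    if sum(counts) != n:
        return -math.inf
    # Log of the multinomial coefficient n! / (x_1! * ... * x_k!).
    logp = math.lgamma(n + 1) - sum(math.lgamma(x + 1) for x in counts)
    for x, prob in zip(counts, p):
        if x > 0:
            if prob == 0:
                return -math.inf
            logp += x * math.log(prob)
    return logp


rng = np.random.default_rng(0)
n = 10
vals = np.zeros((4, 5, 3), dtype="int64")
p = np.zeros(vals.shape, dtype="float64")

# One random category index per batch element, shape (4, 5, 1).
inds = rng.integers(vals.shape[-1], size=vals.shape[:-1])[..., None]
np.put_along_axis(vals, inds, n, axis=-1)  # all n counts in that category
np.put_along_axis(p, inds, 1.0, axis=-1)   # probability 1 on that category

# Evaluate the probability of every batch element separately.
probs = np.empty(vals.shape[:-1])
for idx in np.ndindex(*vals.shape[:-1]):
    probs[idx] = math.exp(multinomial_logp(vals[idx].tolist(), n, p[idx].tolist()))
```

Because each `p` vector puts all its mass on one category and the matching `vals` vector puts all `n` counts there, the outcome is certain. This is also why the test can expect `dist.random` to return `vals` exactly.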
