Skip to content

Commit 248131d

Browse files
committed
Remove legacy docstrings Multinomial restriction on dimensionality of n and p
Refactor vectorized logp tests
1 parent eb925bc commit 248131d

File tree

2 files changed

+31
-110
lines changed

2 files changed

+31
-110
lines changed

pymc/distributions/multivariate.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,13 +507,11 @@ class Multinomial(Discrete):
507507
508508
Parameters
509509
----------
510-
n: int or array
511-
Number of trials (n > 0). If n is an array its shape must be (N,) with
512-
N = p.shape[0]
513-
p: one- or two-dimensional array
514-
Probability of each one of the different outcomes. Elements must
515-
be non-negative and sum to 1 along the last axis. They will be
516-
automatically rescaled otherwise.
510+
n: int
511+
Number of trials (n > 0)
512+
p: vector
513+
Probability of each one of the different outcomes. Elements must be non-negative
514+
and sum to 1 along the last axis. They will be automatically rescaled otherwise.
517515
"""
518516
rv_op = multinomial
519517

pymc/tests/test_distributions.py

Lines changed: 26 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def polyagamma_cdf(*args, **kwargs):
4848
from aesara.tensor.random.op import RandomVariable
4949
from aesara.tensor.var import TensorVariable
5050
from numpy import array, inf, log
51-
from numpy.testing import assert_allclose, assert_almost_equal, assert_equal
51+
from numpy.testing import assert_almost_equal, assert_equal
5252
from scipy import integrate
5353
from scipy.special import erf, gammaln, logit
5454

@@ -327,16 +327,6 @@ def f3(a, b, c):
327327
raise ValueError("Dont know how to integrate shape: " + str(shape))
328328

329329

330-
def multinomial_logpdf(value, n, p):
331-
if value.sum() == n and (0 <= value).all() and (value <= n).all():
332-
logpdf = scipy.special.gammaln(n + 1)
333-
logpdf -= scipy.special.gammaln(value + 1).sum()
334-
logpdf += logpow(p, value).sum()
335-
return logpdf
336-
else:
337-
return -inf
338-
339-
340330
def _dirichlet_multinomial_logpmf(value, n, a):
341331
if value.sum() == n and (0 <= value).all() and (value <= n).all():
342332
sum_a = a.sum()
@@ -2157,7 +2147,10 @@ def test_dirichlet_2D(self):
21572147
@pytest.mark.parametrize("n", [2, 3])
21582148
def test_multinomial(self, n):
21592149
self.check_logp(
2160-
Multinomial, Vector(Nat, n), {"p": Simplex(n), "n": Nat}, multinomial_logpdf
2150+
Multinomial,
2151+
Vector(Nat, n),
2152+
{"p": Simplex(n), "n": Nat},
2153+
lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
21612154
)
21622155

21632156
@pytest.mark.parametrize(
@@ -2187,106 +2180,36 @@ def test_multinomial_random(self, p, size, n):
21872180

21882181
assert m.eval().shape == size + p.shape
21892182

2190-
def test_multinomial_vec(self):
2191-
vals = np.array([[2, 4, 4], [3, 3, 4]])
2192-
p = np.array([0.2, 0.3, 0.5])
2193-
n = 10
2194-
2195-
with Model() as model_single:
2196-
Multinomial("m", n=n, p=p)
2197-
2198-
with Model() as model_many:
2199-
Multinomial("m", n=n, p=p, size=2)
2183+
@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
2184+
@pytest.mark.parametrize(
2185+
"p",
2186+
[
2187+
([0.2, 0.3, 0.5]),
2188+
([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]]),
2189+
(np.abs(np.random.randn(2, 2, 4))),
2190+
],
2191+
)
2192+
@pytest.mark.parametrize("size", [1, 2, (2, 3)])
2193+
def test_multinomial_vectorized(self, n, p, size):
2194+
n = intX(np.array(n))
2195+
p = floatX(np.array(p))
2196+
p /= p.sum(axis=-1, keepdims=True)
22002197

2201-
assert_almost_equal(
2202-
scipy.stats.multinomial.logpmf(vals, n, p),
2203-
np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
2204-
decimal=4,
2205-
)
2198+
mn = pm.Multinomial.dist(n=n, p=p, size=size)
2199+
vals = mn.eval()
22062200

22072201
assert_almost_equal(
22082202
scipy.stats.multinomial.logpmf(vals, n, p),
2209-
logp(model_many.m, vals).eval().squeeze(),
2203+
pm.logp(mn, vals).eval(),
22102204
decimal=4,
2205+
err_msg=f"vals={vals}",
22112206
)
22122207

2213-
assert_almost_equal(
2214-
sum(model_single.fastlogp({"m": val}) for val in vals),
2215-
model_many.fastlogp({"m": vals}),
2216-
decimal=4,
2217-
)
2218-
2219-
def test_multinomial_vec_1d_n(self):
2220-
vals = np.array([[2, 4, 4], [4, 3, 4]])
2221-
p = np.array([0.2, 0.3, 0.5])
2222-
ns = np.array([10, 11])
2223-
2224-
with Model() as model:
2225-
Multinomial("m", n=ns, p=p)
2226-
2227-
assert_almost_equal(
2228-
sum(multinomial_logpdf(val, n, p) for val, n in zip(vals, ns)),
2229-
model.fastlogp({"m": vals}),
2230-
decimal=4,
2231-
)
2232-
2233-
def test_multinomial_vec_1d_n_2d_p(self):
2234-
vals = np.array([[2, 4, 4], [4, 3, 4]])
2235-
ps = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
2236-
ns = np.array([10, 11])
2237-
2238-
with Model() as model:
2239-
Multinomial("m", n=ns, p=ps)
2240-
2241-
assert_almost_equal(
2242-
sum(multinomial_logpdf(val, n, p) for val, n, p in zip(vals, ns, ps)),
2243-
model.fastlogp({"m": vals}),
2244-
decimal=4,
2245-
)
2246-
2247-
def test_multinomial_vec_2d_p(self):
2248-
vals = np.array([[2, 4, 4], [3, 3, 4]])
2249-
ps = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
2250-
n = 10
2251-
2252-
with Model() as model:
2253-
Multinomial("m", n=n, p=ps)
2254-
2255-
assert_almost_equal(
2256-
sum(multinomial_logpdf(val, n, p) for val, p in zip(vals, ps)),
2257-
model.fastlogp({"m": vals}),
2258-
decimal=4,
2259-
)
2260-
2261-
def test_batch_multinomial(self):
2262-
n = 10
2263-
vals = intX(np.zeros((4, 5, 3)))
2264-
p = floatX(np.zeros_like(vals))
2265-
inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
2266-
np.put_along_axis(vals, inds, n, axis=-1)
2267-
np.put_along_axis(p, inds, 1, axis=-1)
2268-
2269-
dist = Multinomial.dist(n=n, p=p)
2270-
logp_mn = at.exp(pm.logp(dist, vals)).eval()
2271-
assert_almost_equal(
2272-
logp_mn,
2273-
np.ones(vals.shape[:-1]),
2274-
decimal=select_by_precision(float64=6, float32=3),
2275-
)
2276-
2277-
dist = Multinomial.dist(n=n, p=p, size=2)
2278-
sample = dist.eval()
2279-
assert_allclose(sample, np.stack([vals, vals], axis=0))
2280-
22812208
def test_multinomial_zero_probs(self):
22822209
# test multinomial accepts 0 probabilities / observations:
2283-
value = aesara.shared(np.array([0, 0, 100], dtype=int))
2284-
logp = pm.Multinomial.logp(value=value, n=100, p=at.constant([0.0, 0.0, 1.0]))
2285-
logp_fn = aesara.function(inputs=[], outputs=logp)
2286-
assert logp_fn() >= 0
2287-
2288-
value.set_value(np.array([50, 50, 0], dtype=int))
2289-
assert np.isneginf(logp_fn())
2210+
mn = pm.Multinomial.dist(n=100, p=[0.0, 0.0, 1.0])
2211+
assert pm.logp(mn, np.array([0, 0, 100])).eval() >= 0
2212+
assert pm.logp(mn, np.array([50, 50, 0])).eval() == -np.inf
22902213

22912214
@pytest.mark.parametrize("n", [2, 3])
22922215
def test_dirichlet_multinomial(self, n):

0 commit comments

Comments
 (0)