Skip to content

Commit 4ed1042

Browse files
committed
Refactor dirichlet vectorized logp tests
1 parent 248131d commit 4ed1042

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

pymc/tests/test_distributions.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,12 @@ def discrete_weibull_logpmf(value, q, beta):
463463
)
464464

465465

466-
def dirichlet_logpdf(value, a):
467-
return floatX((-betafn(a) + logpow(value, a - 1).sum(-1)).sum())
466+
def _dirichlet_logpdf(value, a):
467+
# scipy.stats.dirichlet.logpdf suffers from numerical precision issues
468+
return -betafn(a) + logpow(value, a - 1).sum()
469+
470+
471+
dirichlet_logpdf = np.vectorize(_dirichlet_logpdf, signature="(n),(n)->()")
468472

469473

470474
def categorical_logpdf(value, p):
@@ -2101,32 +2105,34 @@ def test_lkj(self, x, eta, n, lp):
21012105

21022106
@pytest.mark.parametrize("n", [1, 2, 3])
21032107
def test_dirichlet(self, n):
2104-
self.check_logp(Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf)
2105-
2106-
@pytest.mark.parametrize("dist_shape", [(1, 2), (2, 4, 3)])
2107-
def test_dirichlet_with_batch_shapes(self, dist_shape):
2108-
a = np.ones(dist_shape)
2109-
with pm.Model() as model:
2110-
d = pm.Dirichlet("d", a=a)
2111-
2112-
# Generate sample points to test
2113-
d_value = d.tag.value_var
2114-
d_point = d.eval().astype("float64")
2115-
d_point /= d_point.sum(axis=-1)[..., None]
2116-
2117-
if hasattr(d_value.tag, "transform"):
2118-
d_point_trans = d_value.tag.transform.forward(
2119-
at.as_tensor(d_point), *d.owner.inputs
2120-
).eval()
2121-
else:
2122-
d_point_trans = d_point
2108+
self.check_logp(
2109+
Dirichlet,
2110+
Simplex(n),
2111+
{"a": Vector(Rplus, n)},
2112+
dirichlet_logpdf,
2113+
)
21232114

2124-
pymc_res = logpt(d, d_point_trans, jacobian=False, sum=False).eval()
2125-
scipy_res = np.empty_like(pymc_res)
2126-
for idx in np.ndindex(a.shape[:-1]):
2127-
scipy_res[idx] = scipy.stats.dirichlet(a[idx]).logpdf(d_point[idx])
2115+
@pytest.mark.parametrize(
2116+
"a",
2117+
[
2118+
([2, 3, 5]),
2119+
([[2, 3, 5], [9, 19, 3]]),
2120+
(np.abs(np.random.randn(2, 2, 4)) + 1),
2121+
],
2122+
)
2123+
@pytest.mark.parametrize("size", [2, (1, 2), (2, 4, 3)])
2124+
def test_dirichlet_vectorized(self, a, size):
2125+
a = floatX(np.array(a))
2126+
2127+
dir = pm.Dirichlet.dist(a=a, size=size)
2128+
vals = dir.eval()
21282129

2129-
assert_almost_equal(pymc_res, scipy_res)
2130+
assert_almost_equal(
2131+
dirichlet_logpdf(vals, a),
2132+
pm.logp(dir, vals).eval(),
2133+
decimal=4,
2134+
err_msg=f"vals={vals}",
2135+
)
21302136

21312137
def test_dirichlet_shape(self):
21322138
a = at.as_tensor_variable(np.r_[1, 2])
@@ -2136,14 +2142,6 @@ def test_dirichlet_shape(self):
21362142
with pytest.warns(DeprecationWarning), aesara.change_flags(compute_test_value="ignore"):
21372143
dir_rv = Dirichlet.dist(at.vector())
21382144

2139-
def test_dirichlet_2D(self):
2140-
self.check_logp(
2141-
Dirichlet,
2142-
MultiSimplex(2, 2),
2143-
{"a": Vector(Vector(Rplus, 2), 2)},
2144-
dirichlet_logpdf,
2145-
)
2146-
21472145
@pytest.mark.parametrize("n", [2, 3])
21482146
def test_multinomial(self, n):
21492147
self.check_logp(

0 commit comments

Comments
 (0)