Skip to content

Commit b12991c

Browse files
Re-enable v4 xfails in pymc3.distributions.dist_math
1 parent 26cd650 commit b12991c

File tree

3 files changed

+18
-26
lines changed

3 files changed

+18
-26
lines changed

.github/workflows/pytest.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ jobs:
7676
pymc3/tests/test_updates.py
7777
7878
- |
79-
pymc3/tests/test_dist_math.py
8079
pymc3/tests/test_distributions.py
8180
pymc3/tests/test_distributions_random.py
8281
pymc3/tests/test_examples.py

pymc3/distributions/dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def MvNormalLogp():
276276
n, k = delta.shape
277277
n, k = f(n), f(k)
278278
chol_cov = cholesky(cov)
279-
diag = aet.nlinalg.diag(chol_cov)
279+
diag = aet.diag(chol_cov)
280280
ok = aet.all(diag > 0)
281281

282282
chol_cov = aet.switch(ok, chol_cov, aet.fill(chol_cov, 1))
@@ -296,7 +296,7 @@ def dlogp(inputs, gradients):
296296
n, k = delta.shape
297297

298298
chol_cov = cholesky(cov)
299-
diag = aet.nlinalg.diag(chol_cov)
299+
diag = aet.diag(chol_cov)
300300
ok = aet.all(diag > 0)
301301

302302
chol_cov = aet.switch(ok, chol_cov, aet.fill(chol_cov, 1))

pymc3/tests/test_dist_math.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy.testing as npt
2020
import pytest
2121

22+
from aesara.tensor.random.basic import multinomial
2223
from scipy import interpolate, stats
2324

2425
import pymc3 as pm
@@ -91,16 +92,13 @@ def test_alltrue_shape():
9192

9293

9394
class MultinomialA(Discrete):
94-
def __init__(self, n, p, *args, **kwargs):
95-
super().__init__(*args, **kwargs)
95+
rv_op = multinomial
9696

97-
self.n = n
98-
self.p = p
99-
100-
def logp(self, value):
101-
n = self.n
102-
p = self.p
97+
@classmethod
98+
def dist(cls, n, p, *args, **kwargs):
99+
return super().dist([n, p], **kwargs)
103100

101+
def logp(value, n, p):
104102
return bound(
105103
factln(n) - factln(value).sum() + (value * aet.log(p)).sum(),
106104
value >= 0,
@@ -112,16 +110,13 @@ def logp(self, value):
112110

113111

114112
class MultinomialB(Discrete):
115-
def __init__(self, n, p, *args, **kwargs):
116-
super().__init__(*args, **kwargs)
117-
118-
self.n = n
119-
self.p = p
113+
rv_op = multinomial
120114

121-
def logp(self, value):
122-
n = self.n
123-
p = self.p
115+
@classmethod
116+
def dist(cls, n, p, *args, **kwargs):
117+
return super().dist([n, p], **kwargs)
124118

119+
def logp(value, n, p):
125120
return bound(
126121
factln(n) - factln(value).sum() + (value * aet.log(p)).sum(),
127122
aet.all(value >= 0),
@@ -132,26 +127,24 @@ def logp(self, value):
132127
)
133128

134129

135-
@pytest.mark.xfail(reason="This test relies on the deprecated Distribution interface")
136130
def test_multinomial_bound():
137131

138132
x = np.array([1, 5])
139133
n = x.sum()
140134

141135
with pm.Model() as modelA:
142-
p_a = pm.Dirichlet("p", floatX(np.ones(2)), shape=(2,))
136+
p_a = pm.Dirichlet("p", floatX(np.ones(2)))
143137
MultinomialA("x", n, p_a, observed=x)
144138

145139
with pm.Model() as modelB:
146-
p_b = pm.Dirichlet("p", floatX(np.ones(2)), shape=(2,))
140+
p_b = pm.Dirichlet("p", floatX(np.ones(2)))
147141
MultinomialB("x", n, p_b, observed=x)
148142

149143
assert np.isclose(
150144
modelA.logp({"p_stickbreaking__": [0]}), modelB.logp({"p_stickbreaking__": [0]})
151145
)
152146

153147

154-
@pytest.mark.xfail(reason="MvNormal not implemented")
155148
class TestMvNormalLogp:
156149
def test_logp(self):
157150
np.random.seed(42)
@@ -192,11 +185,10 @@ def func(chol_vec, delta):
192185
delta_val = floatX(np.random.randn(5, 2))
193186
verify_grad(func, [chol_vec_val, delta_val])
194187

195-
@pytest.mark.skip(reason="Fix in aesara not released yet: Theano#5908")
196188
@aesara.config.change_flags(compute_test_value="ignore")
197189
def test_hessian(self):
198190
chol_vec = aet.vector("chol_vec")
199-
chol_vec.tag.test_value = np.array([0.1, 2, 3])
191+
chol_vec.tag.test_value = floatX(np.array([0.1, 2, 3]))
200192
chol = aet.stack(
201193
[
202194
aet.stack([aet.exp(0.1 * chol_vec[0]), 0]),
@@ -205,9 +197,10 @@ def test_hessian(self):
205197
)
206198
cov = aet.dot(chol, chol.T)
207199
delta = aet.matrix("delta")
208-
delta.tag.test_value = np.ones((5, 2))
200+
delta.tag.test_value = floatX(np.ones((5, 2)))
209201
logp = MvNormalLogp()(cov, delta)
210202
g_cov, g_delta = aet.grad(logp, [cov, delta])
203+
# TODO: What's the test? Something needs to be asserted.
211204
aet.grad(g_delta.sum() + g_cov.sum(), [delta, cov])
212205

213206

0 commit comments

Comments
 (0)