19
19
import numpy .testing as npt
20
20
import pytest
21
21
22
+ from aesara .tensor .random .basic import multinomial
22
23
from scipy import interpolate , stats
23
24
24
25
import pymc3 as pm
@@ -91,16 +92,13 @@ def test_alltrue_shape():
91
92
92
93
93
94
class MultinomialA (Discrete ):
94
- def __init__ (self , n , p , * args , ** kwargs ):
95
- super ().__init__ (* args , ** kwargs )
95
+ rv_op = multinomial
96
96
97
- self .n = n
98
- self .p = p
99
-
100
- def logp (self , value ):
101
- n = self .n
102
- p = self .p
97
+ @classmethod
98
+ def dist (cls , n , p , * args , ** kwargs ):
99
+ return super ().dist ([n , p ], ** kwargs )
103
100
101
+ def logp (value , n , p ):
104
102
return bound (
105
103
factln (n ) - factln (value ).sum () + (value * aet .log (p )).sum (),
106
104
value >= 0 ,
@@ -112,16 +110,13 @@ def logp(self, value):
112
110
113
111
114
112
class MultinomialB (Discrete ):
115
- def __init__ (self , n , p , * args , ** kwargs ):
116
- super ().__init__ (* args , ** kwargs )
117
-
118
- self .n = n
119
- self .p = p
113
+ rv_op = multinomial
120
114
121
- def logp ( self , value ):
122
- n = self . n
123
- p = self . p
115
+ @ classmethod
116
+ def dist ( cls , n , p , * args , ** kwargs ):
117
+ return super (). dist ([ n , p ], ** kwargs )
124
118
119
+ def logp (value , n , p ):
125
120
return bound (
126
121
factln (n ) - factln (value ).sum () + (value * aet .log (p )).sum (),
127
122
aet .all (value >= 0 ),
@@ -132,26 +127,24 @@ def logp(self, value):
132
127
)
133
128
134
129
135
- @pytest .mark .xfail (reason = "This test relies on the deprecated Distribution interface" )
136
130
def test_multinomial_bound ():
137
131
138
132
x = np .array ([1 , 5 ])
139
133
n = x .sum ()
140
134
141
135
with pm .Model () as modelA :
142
- p_a = pm .Dirichlet ("p" , floatX (np .ones (2 )), shape = ( 2 ,) )
136
+ p_a = pm .Dirichlet ("p" , floatX (np .ones (2 )))
143
137
MultinomialA ("x" , n , p_a , observed = x )
144
138
145
139
with pm .Model () as modelB :
146
- p_b = pm .Dirichlet ("p" , floatX (np .ones (2 )), shape = ( 2 ,) )
140
+ p_b = pm .Dirichlet ("p" , floatX (np .ones (2 )))
147
141
MultinomialB ("x" , n , p_b , observed = x )
148
142
149
143
assert np .isclose (
150
144
modelA .logp ({"p_stickbreaking__" : [0 ]}), modelB .logp ({"p_stickbreaking__" : [0 ]})
151
145
)
152
146
153
147
154
- @pytest .mark .xfail (reason = "MvNormal not implemented" )
155
148
class TestMvNormalLogp :
156
149
def test_logp (self ):
157
150
np .random .seed (42 )
@@ -192,11 +185,10 @@ def func(chol_vec, delta):
192
185
delta_val = floatX (np .random .randn (5 , 2 ))
193
186
verify_grad (func , [chol_vec_val , delta_val ])
194
187
195
- @pytest .mark .skip (reason = "Fix in aesara not released yet: Theano#5908" )
196
188
@aesara .config .change_flags (compute_test_value = "ignore" )
197
189
def test_hessian (self ):
198
190
chol_vec = aet .vector ("chol_vec" )
199
- chol_vec .tag .test_value = np .array ([0.1 , 2 , 3 ])
191
+ chol_vec .tag .test_value = floatX ( np .array ([0.1 , 2 , 3 ]) )
200
192
chol = aet .stack (
201
193
[
202
194
aet .stack ([aet .exp (0.1 * chol_vec [0 ]), 0 ]),
@@ -205,9 +197,10 @@ def test_hessian(self):
205
197
)
206
198
cov = aet .dot (chol , chol .T )
207
199
delta = aet .matrix ("delta" )
208
- delta .tag .test_value = np .ones ((5 , 2 ))
200
+ delta .tag .test_value = floatX ( np .ones ((5 , 2 ) ))
209
201
logp = MvNormalLogp ()(cov , delta )
210
202
g_cov , g_delta = aet .grad (logp , [cov , delta ])
203
+ # TODO: What's the test? Something needs to be asserted.
211
204
aet .grad (g_delta .sum () + g_cov .sum (), [delta , cov ])
212
205
213
206
0 commit comments