@@ -48,7 +48,7 @@ def polyagamma_cdf(*args, **kwargs):
48
48
from aesara .tensor .random .op import RandomVariable
49
49
from aesara .tensor .var import TensorVariable
50
50
from numpy import array , inf , log
51
- from numpy .testing import assert_allclose , assert_almost_equal , assert_equal
51
+ from numpy .testing import assert_almost_equal , assert_equal
52
52
from scipy import integrate
53
53
from scipy .special import erf , gammaln , logit
54
54
@@ -327,16 +327,6 @@ def f3(a, b, c):
327
327
raise ValueError ("Dont know how to integrate shape: " + str (shape ))
328
328
329
329
330
- def multinomial_logpdf (value , n , p ):
331
- if value .sum () == n and (0 <= value ).all () and (value <= n ).all ():
332
- logpdf = scipy .special .gammaln (n + 1 )
333
- logpdf -= scipy .special .gammaln (value + 1 ).sum ()
334
- logpdf += logpow (p , value ).sum ()
335
- return logpdf
336
- else :
337
- return - inf
338
-
339
-
340
330
def _dirichlet_multinomial_logpmf (value , n , a ):
341
331
if value .sum () == n and (0 <= value ).all () and (value <= n ).all ():
342
332
sum_a = a .sum ()
@@ -2157,7 +2147,10 @@ def test_dirichlet_2D(self):
2157
2147
@pytest .mark .parametrize ("n" , [2 , 3 ])
2158
2148
def test_multinomial (self , n ):
2159
2149
self .check_logp (
2160
- Multinomial , Vector (Nat , n ), {"p" : Simplex (n ), "n" : Nat }, multinomial_logpdf
2150
+ Multinomial ,
2151
+ Vector (Nat , n ),
2152
+ {"p" : Simplex (n ), "n" : Nat },
2153
+ lambda value , n , p : scipy .stats .multinomial .logpmf (value , n , p ),
2161
2154
)
2162
2155
2163
2156
@pytest .mark .parametrize (
@@ -2187,106 +2180,36 @@ def test_multinomial_random(self, p, size, n):
2187
2180
2188
2181
assert m .eval ().shape == size + p .shape
2189
2182
2190
- def test_multinomial_vec (self ):
2191
- vals = np .array ([[2 , 4 , 4 ], [3 , 3 , 4 ]])
2192
- p = np .array ([0.2 , 0.3 , 0.5 ])
2193
- n = 10
2194
-
2195
- with Model () as model_single :
2196
- Multinomial ("m" , n = n , p = p )
2197
-
2198
- with Model () as model_many :
2199
- Multinomial ("m" , n = n , p = p , size = 2 )
2183
+ @pytest .mark .parametrize ("n" , [(10 ), ([10 , 11 ]), ([[5 , 6 ], [10 , 11 ]])])
2184
+ @pytest .mark .parametrize (
2185
+ "p" ,
2186
+ [
2187
+ ([0.2 , 0.3 , 0.5 ]),
2188
+ ([[0.2 , 0.3 , 0.5 ], [0.9 , 0.09 , 0.01 ]]),
2189
+ (np .abs (np .random .randn (2 , 2 , 4 ))),
2190
+ ],
2191
+ )
2192
+ @pytest .mark .parametrize ("size" , [1 , 2 , (2 , 3 )])
2193
+ def test_multinomial_vectorized (self , n , p , size ):
2194
+ n = intX (np .array (n ))
2195
+ p = floatX (np .array (p ))
2196
+ p /= p .sum (axis = - 1 , keepdims = True )
2200
2197
2201
- assert_almost_equal (
2202
- scipy .stats .multinomial .logpmf (vals , n , p ),
2203
- np .asarray ([model_single .fastlogp ({"m" : val }) for val in vals ]),
2204
- decimal = 4 ,
2205
- )
2198
+ mn = pm .Multinomial .dist (n = n , p = p , size = size )
2199
+ vals = mn .eval ()
2206
2200
2207
2201
assert_almost_equal (
2208
2202
scipy .stats .multinomial .logpmf (vals , n , p ),
2209
- logp (model_many . m , vals ).eval (). squeeze (),
2203
+ pm . logp (mn , vals ).eval (),
2210
2204
decimal = 4 ,
2205
+ err_msg = f"vals={ vals } " ,
2211
2206
)
2212
2207
2213
- assert_almost_equal (
2214
- sum (model_single .fastlogp ({"m" : val }) for val in vals ),
2215
- model_many .fastlogp ({"m" : vals }),
2216
- decimal = 4 ,
2217
- )
2218
-
2219
- def test_multinomial_vec_1d_n (self ):
2220
- vals = np .array ([[2 , 4 , 4 ], [4 , 3 , 4 ]])
2221
- p = np .array ([0.2 , 0.3 , 0.5 ])
2222
- ns = np .array ([10 , 11 ])
2223
-
2224
- with Model () as model :
2225
- Multinomial ("m" , n = ns , p = p )
2226
-
2227
- assert_almost_equal (
2228
- sum (multinomial_logpdf (val , n , p ) for val , n in zip (vals , ns )),
2229
- model .fastlogp ({"m" : vals }),
2230
- decimal = 4 ,
2231
- )
2232
-
2233
- def test_multinomial_vec_1d_n_2d_p (self ):
2234
- vals = np .array ([[2 , 4 , 4 ], [4 , 3 , 4 ]])
2235
- ps = np .array ([[0.2 , 0.3 , 0.5 ], [0.9 , 0.09 , 0.01 ]])
2236
- ns = np .array ([10 , 11 ])
2237
-
2238
- with Model () as model :
2239
- Multinomial ("m" , n = ns , p = ps )
2240
-
2241
- assert_almost_equal (
2242
- sum (multinomial_logpdf (val , n , p ) for val , n , p in zip (vals , ns , ps )),
2243
- model .fastlogp ({"m" : vals }),
2244
- decimal = 4 ,
2245
- )
2246
-
2247
- def test_multinomial_vec_2d_p (self ):
2248
- vals = np .array ([[2 , 4 , 4 ], [3 , 3 , 4 ]])
2249
- ps = np .array ([[0.2 , 0.3 , 0.5 ], [0.3 , 0.3 , 0.4 ]])
2250
- n = 10
2251
-
2252
- with Model () as model :
2253
- Multinomial ("m" , n = n , p = ps )
2254
-
2255
- assert_almost_equal (
2256
- sum (multinomial_logpdf (val , n , p ) for val , p in zip (vals , ps )),
2257
- model .fastlogp ({"m" : vals }),
2258
- decimal = 4 ,
2259
- )
2260
-
2261
- def test_batch_multinomial (self ):
2262
- n = 10
2263
- vals = intX (np .zeros ((4 , 5 , 3 )))
2264
- p = floatX (np .zeros_like (vals ))
2265
- inds = np .random .randint (vals .shape [- 1 ], size = vals .shape [:- 1 ])[..., None ]
2266
- np .put_along_axis (vals , inds , n , axis = - 1 )
2267
- np .put_along_axis (p , inds , 1 , axis = - 1 )
2268
-
2269
- dist = Multinomial .dist (n = n , p = p )
2270
- logp_mn = at .exp (pm .logp (dist , vals )).eval ()
2271
- assert_almost_equal (
2272
- logp_mn ,
2273
- np .ones (vals .shape [:- 1 ]),
2274
- decimal = select_by_precision (float64 = 6 , float32 = 3 ),
2275
- )
2276
-
2277
- dist = Multinomial .dist (n = n , p = p , size = 2 )
2278
- sample = dist .eval ()
2279
- assert_allclose (sample , np .stack ([vals , vals ], axis = 0 ))
2280
-
2281
2208
def test_multinomial_zero_probs (self ):
2282
2209
# test multinomial accepts 0 probabilities / observations:
2283
- value = aesara .shared (np .array ([0 , 0 , 100 ], dtype = int ))
2284
- logp = pm .Multinomial .logp (value = value , n = 100 , p = at .constant ([0.0 , 0.0 , 1.0 ]))
2285
- logp_fn = aesara .function (inputs = [], outputs = logp )
2286
- assert logp_fn () >= 0
2287
-
2288
- value .set_value (np .array ([50 , 50 , 0 ], dtype = int ))
2289
- assert np .isneginf (logp_fn ())
2210
+ mn = pm .Multinomial .dist (n = 100 , p = [0.0 , 0.0 , 1.0 ])
2211
+ assert pm .logp (mn , np .array ([0 , 0 , 100 ])).eval () >= 0
2212
+ assert pm .logp (mn , np .array ([50 , 50 , 0 ])).eval () == - np .inf
2290
2213
2291
2214
@pytest .mark .parametrize ("n" , [2 , 3 ])
2292
2215
def test_dirichlet_multinomial (self , n ):
0 commit comments