@@ -463,8 +463,12 @@ def discrete_weibull_logpmf(value, q, beta):
463
463
)
464
464
465
465
466
- def dirichlet_logpdf (value , a ):
467
- return floatX ((- betafn (a ) + logpow (value , a - 1 ).sum (- 1 )).sum ())
466
+ def _dirichlet_logpdf (value , a ):
467
+ # scipy.stats.dirichlet.logpdf suffers from numerical precision issues
468
+ return - betafn (a ) + logpow (value , a - 1 ).sum ()
469
+
470
+
471
+ dirichlet_logpdf = np .vectorize (_dirichlet_logpdf , signature = "(n),(n)->()" )
468
472
469
473
470
474
def categorical_logpdf (value , p ):
@@ -2101,32 +2105,34 @@ def test_lkj(self, x, eta, n, lp):
2101
2105
2102
2106
@pytest .mark .parametrize ("n" , [1 , 2 , 3 ])
2103
2107
def test_dirichlet (self , n ):
2104
- self .check_logp (Dirichlet , Simplex (n ), {"a" : Vector (Rplus , n )}, dirichlet_logpdf )
2105
-
2106
- @pytest .mark .parametrize ("dist_shape" , [(1 , 2 ), (2 , 4 , 3 )])
2107
- def test_dirichlet_with_batch_shapes (self , dist_shape ):
2108
- a = np .ones (dist_shape )
2109
- with pm .Model () as model :
2110
- d = pm .Dirichlet ("d" , a = a )
2111
-
2112
- # Generate sample points to test
2113
- d_value = d .tag .value_var
2114
- d_point = d .eval ().astype ("float64" )
2115
- d_point /= d_point .sum (axis = - 1 )[..., None ]
2116
-
2117
- if hasattr (d_value .tag , "transform" ):
2118
- d_point_trans = d_value .tag .transform .forward (
2119
- at .as_tensor (d_point ), * d .owner .inputs
2120
- ).eval ()
2121
- else :
2122
- d_point_trans = d_point
2108
+ self .check_logp (
2109
+ Dirichlet ,
2110
+ Simplex (n ),
2111
+ {"a" : Vector (Rplus , n )},
2112
+ dirichlet_logpdf ,
2113
+ )
2123
2114
2124
- pymc_res = logpt (d , d_point_trans , jacobian = False , sum = False ).eval ()
2125
- scipy_res = np .empty_like (pymc_res )
2126
- for idx in np .ndindex (a .shape [:- 1 ]):
2127
- scipy_res [idx ] = scipy .stats .dirichlet (a [idx ]).logpdf (d_point [idx ])
2115
+ @pytest .mark .parametrize (
2116
+ "a" ,
2117
+ [
2118
+ ([2 , 3 , 5 ]),
2119
+ ([[2 , 3 , 5 ], [9 , 19 , 3 ]]),
2120
+ (np .abs (np .random .randn (2 , 2 , 4 )) + 1 ),
2121
+ ],
2122
+ )
2123
+ @pytest .mark .parametrize ("size" , [2 , (1 , 2 ), (2 , 4 , 3 )])
2124
+ def test_dirichlet_vectorized (self , a , size ):
2125
+ a = floatX (np .array (a ))
2126
+
2127
+ dir = pm .Dirichlet .dist (a = a , size = size )
2128
+ vals = dir .eval ()
2128
2129
2129
- assert_almost_equal (pymc_res , scipy_res )
2130
+ assert_almost_equal (
2131
+ dirichlet_logpdf (vals , a ),
2132
+ pm .logp (dir , vals ).eval (),
2133
+ decimal = 4 ,
2134
+ err_msg = f"vals={ vals } " ,
2135
+ )
2130
2136
2131
2137
def test_dirichlet_shape (self ):
2132
2138
a = at .as_tensor_variable (np .r_ [1 , 2 ])
@@ -2136,14 +2142,6 @@ def test_dirichlet_shape(self):
2136
2142
with pytest .warns (DeprecationWarning ), aesara .change_flags (compute_test_value = "ignore" ):
2137
2143
dir_rv = Dirichlet .dist (at .vector ())
2138
2144
2139
- def test_dirichlet_2D (self ):
2140
- self .check_logp (
2141
- Dirichlet ,
2142
- MultiSimplex (2 , 2 ),
2143
- {"a" : Vector (Vector (Rplus , 2 ), 2 )},
2144
- dirichlet_logpdf ,
2145
- )
2146
-
2147
2145
@pytest .mark .parametrize ("n" , [2 , 3 ])
2148
2146
def test_multinomial (self , n ):
2149
2147
self .check_logp (
0 commit comments